From 6774af43c24801aa706654ea3679385913272176 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Thu, 3 Sep 2020 10:13:47 -0700 Subject: [PATCH 1/6] Implement horizontal input fusion. Extend horizontal fusion to support fusion of reduction instructions. --- tensorflow/compiler/xla/service/gpu/BUILD | 43 +++++ .../compiler/xla/service/gpu/gpu_compiler.cc | 2 + .../compiler/xla/service/gpu/gpu_fusible.cc | 63 ++++--- .../compiler/xla/service/gpu/gpu_fusible.h | 9 + .../service/gpu/horizontal_input_fusion.cc | 157 ++++++++++++++++++ .../xla/service/gpu/horizontal_input_fusion.h | 57 +++++++ .../gpu/horizontal_input_fusion_test.cc | 146 ++++++++++++++++ .../xla/service/gpu/horizontal_loop_fusion.cc | 20 +-- 8 files changed, 455 insertions(+), 42 deletions(-) create mode 100644 tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc create mode 100644 tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h create mode 100644 tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b2ec656a2ba..39b32ec2d1d 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1187,6 +1187,7 @@ cc_library( ":gpu_sanitize_constant_names", ":gpu_scatter_expander", ":horizontal_loop_fusion", + ":horizontal_input_fusion", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", @@ -1769,6 +1770,7 @@ cc_library( srcs = ["horizontal_loop_fusion.cc"], hdrs = ["horizontal_loop_fusion.h"], deps = [ + ":gpu_fusible", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_creation_utils", @@ -1805,6 +1807,47 @@ tf_cc_test( ], ) +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + ":gpu_fusible", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + deps = [ + ":fusion_merger", + ":horizontal_input_fusion", + ":instruction_fusion", + ":multi_output_fusion", + "//tensorflow/compiler/jit:xla_gpu_jit", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "reduction_degenerate_dim_remover", srcs = ["reduction_degenerate_dim_remover.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index e9435f4fa92..feedff0e0b3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -59,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h" #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h" +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -306,6 +307,7 @@ Status GpuCompiler::OptimizeHloModule( HloPassPipeline horizontal_fusion("horizontal_fusion"); horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index ce319b4c59d..9f8c3c81ad2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { IsReductionFromOrToContiguousDimensions(instr); } +const HloInstruction* GetMajorNodeForMultiOutputFusion( + const HloInstruction& instr) { + if (instr.opcode() != HloOpcode::kFusion) { + return &instr; + } + auto fused_expression_root = instr.fused_expression_root(); + if (!instr.IsMultiOutputFusion()) { + return fused_expression_root; + } + // If possible, we want to pick a reduction-to-vector operand of the + // fusion root, because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (IsReductionFromOrToContiguousDimensions(*inst)) { + return inst; + } + } + return fused_expression_root->operands()[0]; +} + bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, const HloInstruction& instr2) { - // Returns the instructions that determines the emitter used for lowering, - // sometimes referred to as "the real hero". - auto get_real_hero = - [&](const HloInstruction* instr) -> const HloInstruction* { - if (instr->opcode() != HloOpcode::kFusion) { - return instr; - } - auto fused_expression_root = instr->fused_expression_root(); - if (!instr->IsMultiOutputFusion()) { - return fused_expression_root; - } - // If possible, we want to pick a reduction-to-vector operand of the - // fusion root, because it has the most constraints. - for (const auto* inst : fused_expression_root->operands()) { - if (IsReductionFromOrToContiguousDimensions(*inst)) { - return inst; - } - } - return fused_expression_root->operands()[0]; - }; - // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - auto* instr_1 = get_real_hero(&instr1); - auto* instr_2 = get_real_hero(&instr2); + auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1); + auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2); if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { @@ -524,5 +522,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/, : HloInstruction::FusionKind::kLoop; } +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer) { + return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + // Skip GTE. + return IsConsumerTheOnlyNonRootUser(*user, consumer); + } else if (user == &consumer) { + // `user` is `consumer`. + return true; + } else if (user == user->parent()->root_instruction()) { + // Consumed by ROOT is always fine, since it is impossible to create + // cycles through ROOT. + return true; + } else { + return false; + } + }); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index e7cac6e55c8..8595bb24ddf 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -71,6 +71,11 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1, bool CreatesNestedLoop(const HloInstruction& producer, const HloInstruction& consumer); +// Returns the instruction that determines the emitter used for lowering, +// sometimes referred to as "the real hero". +const HloInstruction* GetMajorNodeForMultiOutputFusion( + const HloInstruction& instr); + // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. // This function works for both, sibling and producer-consumer multi-output @@ -100,6 +105,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, const HloInstruction& consumer); +// Returns whether `consumer` is the only non-root user of `instr`. +bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, + const HloInstruction& consumer); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc new file mode 100644 index 00000000000..75a69611780 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -0,0 +1,157 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/core/platform/errors.h" + +namespace xla { +namespace gpu { + +namespace { + +// Gets the representative input shape of the multi-output fusion. +Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { + // Get the major node used in the emitter. + const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr); + if (real_hero->operands().empty()) { + // Simply return an empty shape if the representative node has no input + // operands. + return Shape(); + } else { + return real_hero->operand(0)->shape(); + } +} + +class HorizontalInputFusionImpl { + public: + explicit HorizontalInputFusionImpl(HloComputation* computation) + : computation_(computation) {} + + ~HorizontalInputFusionImpl() {} + + StatusOr Run(); + + private: + HloComputation* computation_; +}; // HorizontalInputFusionImpl + +std::vector FindAndSortFusionCandidates( + HloInstruction* consumer) { + absl::flat_hash_set fusion_instr_set; + for (auto opnd : consumer->operands()) { + HloInstruction* predecessor = opnd->LatestNonGteAncestor(); + // Find out the input fusion instructions whose only consumer is `consumer`. + // This guarantees that fusing these candidates will never create cycles, as + // there is no back edge. + if (IsReduceInputFusion(*predecessor) && + IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { + fusion_instr_set.insert(predecessor); + } + } + + std::vector fusion_instrs; + fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(), + fusion_instr_set.end()); + + std::sort(fusion_instrs.begin(), fusion_instrs.end(), + [&](const HloInstruction* a, const HloInstruction* b) { + Shape shape_a = GetInputShapeForMultiOutputFusion(*a); + Shape shape_b = GetInputShapeForMultiOutputFusion(*b); + if (shape_a.rank() != shape_b.rank()) { + return shape_a.rank() < shape_b.rank(); + } else if (ShapeUtil::ElementsIn(shape_a) != + ShapeUtil::ElementsIn(shape_b)) { + // Sort according to element size so that roughly the same input + // shape will be placed adjacent. + return ShapeUtil::ElementsIn(shape_a) < + ShapeUtil::ElementsIn(shape_b); + } else { + // Sort `fusion_instrs` according to instruction counts, because + // we'd like to fuse together computations of similar sizes. + return a->fused_instruction_count() < + b->fused_instruction_count(); + } + }); + + return fusion_instrs; +} + +StatusOr HorizontalInputFusionImpl::Run() { + bool changed = false; + XLA_VLOG_LINES(3, computation_->ToString()); + + // Using def-to-use order is sound since we do not modify users. + std::vector def_to_use_order = + computation_->MakeInstructionPostOrder(); + for (size_t i = 0; i < def_to_use_order.size(); ++i) { + auto consumer = def_to_use_order[i]; + auto candidates = FindAndSortFusionCandidates(consumer); + if (candidates.empty()) { + continue; + } + + size_t fusion_anchor_id = 0; + for (size_t j = 1; j < candidates.size(); ++j) { + HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; + HloInstruction* fused = candidates[j]; + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { + VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ", + fusion_anchor->ToString()); + fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); + changed = true; + } else { + // Update the `fusion_anchor_id` since `fused` is either not + // compatible or not beneficial to be fused with current fusion anchor. + VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1, + " instructions are fused"); + fusion_anchor_id = j; + } + } + } + + return changed; +} + +} // namespace + +StatusOr GpuHorizontalInputFusion::RunOnComputation( + HloComputation* computation) { + HorizontalInputFusionImpl horizontal_fusion_impl(computation); + return horizontal_fusion_impl.Run(); +} + +StatusOr GpuHorizontalInputFusion::Run(HloModule* module) { + bool changed = false; + VLOG(2) << "Run horizontal input fusion."; + for (auto* comp : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp)); + } + + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h new file mode 100644 index 00000000000..85313d03412 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace gpu { + +// This optimization pass horizontally fuses kInput fusions to both reduce the +// kernel launch overhead and increase parallelism degree. See +// GpuHorizontalFusion for general description and motivation about horizontal +// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals +// with kInput fusions. +// +// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// to search the fusion candidates while avoiding creating cycles. That is, +// we simply search for fusion candidates by looking for instructions whose +// outputs are all consumed by the same instruction. This catches the typical +// target cases; often, the candidate instructions are just consumed by the +// ROOT tuple of the entry computation. +class GpuHorizontalInputFusion : public HloModulePass { + public: + GpuHorizontalInputFusion() {} + + absl::string_view name() const override { + return "gpu_horizontal_input_fusion"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation*); +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc new file mode 100644 index 00000000000..035658fe55e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h" + +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HorizontalInputFusionTest : public GpuCodegenTest {}; + +TEST_F(HorizontalInputFusionTest, BasicTest) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule BasicTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2 + ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2) + } +)").ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); + + const HloInstruction* entry_root = + module->entry_computation()->root_instruction(); + EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())), + (op::GetTupleElement(op::Fusion())))); + + const HloInstruction* fusion = entry_root->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(HorizontalInputFusionTest, ManyInputFusions) { + auto module = CreateNewVerifiedModule(); + + HloComputation* reduce_computation; + { + auto embedded_builder = HloComputation::Builder("add"); + auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "lhs")); + auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "rhs")); + embedded_builder.AddInstruction( + HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs)); + reduce_computation = + module->AddEmbeddedComputation(embedded_builder.Build()); + } + + HloComputation::Builder builder(TestName()); + std::vector var_outs; + auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024}); + auto output_shape = ShapeUtil::MakeShape(F32, {1024}); + for (int64 i = 0; i < 130; ++i) { + //%fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) -> + // f32[1024] { + // %param_0 = f32[1024,1024]{1,0} parameter(0) + // %param_1 = f32[] parameter(1) + // %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1), + // dimensions={} + // %multiply = f32[1024,1024]{1,0} + // multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0} + // %broadcast) + // %constant0 = f32[] constant(0) + // ROOT %reduce = f32[1024]{0} + // reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0), + // dimensions={1}, to_apply=%add + //} + HloInstruction* param_var_in = builder.AddInstruction( + HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in")); + HloInstruction* param_alpha = + builder.AddInstruction(HloInstruction::CreateParameter( + i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha")); + auto alpha_broadcasted = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, param_alpha, {})); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted)); + HloInstruction* const0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( + output_shape, mul, const0, {1}, reduce_computation)); + var_outs.push_back(reduce); + } + builder.AddInstruction(HloInstruction::CreateTuple(var_outs)); + module->AddEntryComputation(builder.Build()); + + // Verify that horizontal fusion is kicked in. Check that there are multiple + // `reduce` instructions fused into the same fusion. 6 is just a randomly + // picked number as we don't exactly know how large the fusion will be + // created. + CompileAndVerifyIr(module->Clone(), + R"(CHECK: reduce-group-6)", + /*match_optimized_ir=*/false); + + // Testing with the entire gpu optimization pipeline. + EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc index 577c7eed6c4..9d1e0533a91 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/env_var.h" @@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) { return true; } -bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, - const HloInstruction& consumer) { - return absl::c_all_of(instr.users(), [&](const HloInstruction* user) { - if (user->opcode() == HloOpcode::kGetTupleElement) { - // Skip GTE. - return IsConsumerTheOnlyNonRootUser(*user, consumer); - } else if (user == &consumer) { - // `user` is `consumer`. - return true; - } else if (user == user->parent()->root_instruction()) { - // Consumed by ROOT is always fine, since it is impossible to create - // cycles through ROOT. - return true; - } else { - return false; - } - }); -} - // Returns whether `instr` is a profitable candidate to be horizontally fused. // Since the primary benefit of horizontal fusion comes from reducing the // kernel launch overhead, we want to exclude the instructions with From ea0b5fa33f6ed23039e9f580eaec8bd69d52ec82 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 9 Sep 2020 13:25:52 -0700 Subject: [PATCH 2/6] [XLA/GPU] Address review comments. --- tensorflow/compiler/xla/service/gpu/BUILD | 17 +++++-------- .../compiler/xla/service/gpu/gpu_fusible.cc | 24 +++++++++---------- .../compiler/xla/service/gpu/gpu_fusible.h | 2 +- .../service/gpu/horizontal_input_fusion.cc | 10 ++++---- .../gpu/horizontal_input_fusion_test.cc | 2 +- 5 files changed, 24 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 39b32ec2d1d..39dad267acf 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1814,10 +1814,10 @@ cc_library( deps = [ ":gpu_fusible", ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/container:flat_hash_set", @@ -1827,22 +1827,17 @@ cc_library( tf_cc_test( name = "horizontal_input_fusion_test", srcs = ["horizontal_input_fusion_test.cc"], + tags = tf_cuda_tests_tags(), deps = [ - ":fusion_merger", ":horizontal_input_fusion", - ":instruction_fusion", ":multi_output_fusion", - "//tensorflow/compiler/jit:xla_gpu_jit", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", - "//tensorflow/compiler/xla/service:hlo_dce", - "//tensorflow/compiler/xla/service:hlo_matchers", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_pass_pipeline", + "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc index 9f8c3c81ad2..b69b32c17c5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc @@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) { IsReductionFromOrToContiguousDimensions(instr); } -const HloInstruction* GetMajorNodeForMultiOutputFusion( +const HloInstruction* GetRealHeroForMultiOutputFusion( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kFusion) { return &instr; @@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion( if (!instr.IsMultiOutputFusion()) { return fused_expression_root; } - // If possible, we want to pick a reduction-to-vector operand of the - // fusion root, because it has the most constraints. + // If possible, we want to pick a reduction-from-or-to-contiguous-dims + // operand of the fusion root, because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { if (IsReductionFromOrToContiguousDimensions(*inst)) { return inst; @@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1, // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1); - auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2); + auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1); + auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2); if (IsReductionFromOrToContiguousDimensions(*instr_1) && IsReductionFromOrToContiguousDimensions(*instr_2) && !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { @@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, if (user->opcode() == HloOpcode::kGetTupleElement) { // Skip GTE. return IsConsumerTheOnlyNonRootUser(*user, consumer); - } else if (user == &consumer) { + } + if (user == &consumer) { // `user` is `consumer`. return true; - } else if (user == user->parent()->root_instruction()) { - // Consumed by ROOT is always fine, since it is impossible to create - // cycles through ROOT. - return true; - } else { - return false; } + if (user == user->parent()->root_instruction()) { + // Consumed by ROOT. + return true; + } + return false; }); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h index 8595bb24ddf..9fa098a3394 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h @@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer, // Returns the instruction that determines the emitter used for lowering, // sometimes referred to as "the real hero". -const HloInstruction* GetMajorNodeForMultiOutputFusion( +const HloInstruction* GetRealHeroForMultiOutputFusion( const HloInstruction& instr); // Whether instruction shapes are compatible for multi-output fusion, i.e. diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc index 75a69611780..f25a283e4b9 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -33,8 +33,8 @@ namespace { // Gets the representative input shape of the multi-output fusion. Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { - // Get the major node used in the emitter. - const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr); + // Get the HLO that determines the emitter used for lowering. + const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); if (real_hero->operands().empty()) { // Simply return an empty shape if the representative node has no input // operands. @@ -118,15 +118,13 @@ StatusOr HorizontalInputFusionImpl::Run() { HloInstruction* fused = candidates[j]; if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { - VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ", - fusion_anchor->ToString()); + VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString(); fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); changed = true; } else { // Update the `fusion_anchor_id` since `fused` is either not // compatible or not beneficial to be fused with current fusion anchor. - VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1, - " instructions are fused"); + VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused."; fusion_anchor_id = j; } } diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc index 035658fe55e..f27e77fad68 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) { // Verify that horizontal fusion is kicked in. Check that there are multiple // `reduce` instructions fused into the same fusion. 6 is just a randomly // picked number as we don't exactly know how large the fusion will be - // created. + // created due to the `FusionWouldBeTooLarge` constraint. CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", /*match_optimized_ir=*/false); From 9e8d910e3bfde47ed2470080aff8ccf38886a2db Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 9 Sep 2020 16:21:30 -0700 Subject: [PATCH 3/6] [XLA/GPU] Revise the shape comparison function in horizontal_input_fusion. So that we can distinguish [128,256] and [256,128]. --- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../service/gpu/horizontal_input_fusion.cc | 38 ++++++++++++------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 39dad267acf..d81182f6036 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1831,6 +1831,7 @@ tf_cc_test( deps = [ ":horizontal_input_fusion", ":multi_output_fusion", + "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc index f25a283e4b9..75f12928067 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -57,6 +57,23 @@ class HorizontalInputFusionImpl { HloComputation* computation_; }; // HorizontalInputFusionImpl +// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to +// right. +bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, + const Shape& shape_b) { + if (shape_a.rank() != shape_b.rank()) { + return shape_a.rank() < shape_b.rank(); + } + auto dims_a = shape_a.dimensions(); + auto dims_b = shape_b.dimensions(); + for (size_t i = 0; i < dims_a.size(); ++i) { + if (dims_a[i] != dims_b[i]) { + return dims_a[i] < dims_b[i]; + } + } + return true; +} + std::vector FindAndSortFusionCandidates( HloInstruction* consumer) { absl::flat_hash_set fusion_instr_set; @@ -79,20 +96,15 @@ std::vector FindAndSortFusionCandidates( [&](const HloInstruction* a, const HloInstruction* b) { Shape shape_a = GetInputShapeForMultiOutputFusion(*a); Shape shape_b = GetInputShapeForMultiOutputFusion(*b); - if (shape_a.rank() != shape_b.rank()) { - return shape_a.rank() < shape_b.rank(); - } else if (ShapeUtil::ElementsIn(shape_a) != - ShapeUtil::ElementsIn(shape_b)) { - // Sort according to element size so that roughly the same input - // shape will be placed adjacent. - return ShapeUtil::ElementsIn(shape_a) < - ShapeUtil::ElementsIn(shape_b); - } else { - // Sort `fusion_instrs` according to instruction counts, because - // we'd like to fuse together computations of similar sizes. - return a->fused_instruction_count() < - b->fused_instruction_count(); + if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { + // Sort shapes according to dimensions, so that the same input + // shape will be placed adjacent each other. + return CompareShapeDimsFromLeftToRight(shape_a, shape_b); } + // Sort `fusion_instrs` according to instruction counts, because + // we'd like to fuse together computations of similar sizes. + return a->fused_instruction_count() < + b->fused_instruction_count(); }); return fusion_instrs; From 018fb5df02ac36da4638226e55fb9de9cd52ac08 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Wed, 9 Sep 2020 16:33:48 -0700 Subject: [PATCH 4/6] [XLA/GPU] minor comment polishing. --- tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc index 75f12928067..f0bd144c305 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -98,7 +98,7 @@ std::vector FindAndSortFusionCandidates( Shape shape_b = GetInputShapeForMultiOutputFusion(*b); if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { // Sort shapes according to dimensions, so that the same input - // shape will be placed adjacent each other. + // shapes will be placed adjacent each other. return CompareShapeDimsFromLeftToRight(shape_a, shape_b); } // Sort `fusion_instrs` according to instruction counts, because From 4c12d776dde7c8b1ea5ed4800ebfcf15ad5b16e2 Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Mon, 12 Oct 2020 18:06:27 -0700 Subject: [PATCH 5/6] [XLA/GPU] Add test MultiOutputFusionTest. --- .../gpu/horizontal_input_fusion_test.cc | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc index f27e77fad68..9ec07ce1a7e 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc @@ -141,6 +141,75 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) { EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); } +TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) { + // This tests the below pattern. One known issue is that gtes (to fusions) can + // be removed after their producer fusions are merged. In the below case, gte2 + // and gte6 will be gone if Fusion2 is fused into Fusion1. + // + // Fusion1 Fusion2 + // | | | | + // | gte1 gte2 | + // | | | | + // | Fusion3 | + // | | | | + // gte3 gte4 gte5 gte6 + // \ | | / + // =====ROOT===== + // + auto module = ParseAndReturnVerifiedModule(R"( + HloModule MultiOutputFusionTest + + %add_f16 { + %x = f16[] parameter(0) + %y = f16[] parameter(1) + ROOT %add = f16[] add(%x, %y) + } + + fused_computation.1 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.2 { + arg.1 = f16[1024]{0} parameter(0) + constant0 = f16[] constant(0) + reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16 + add.0 = f16[1024] add(arg.1, arg.1) + ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0) + } + + fused_computation.3 { + arg.0 = f16[1024]{0} parameter(0) + arg.1 = f16[1024]{0} parameter(1) + add.0 = f16[1024] add(arg.0, arg.1) + mul.0 = f16[1024] multiply(arg.0, arg.1) + ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0) + } + + ENTRY entry_computation { + arg.1 = f16[1024]{0} parameter(0) + arg.2 = f16[1024]{0} parameter(1) + fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1 + fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2 + gte.3 = f16[] get-tuple-element(fusion.1), index=0 + gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1 + gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1 + gte.6 = f16[] get-tuple-element(fusion.2), index=0 + fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2), + kind=kLoop, calls=fused_computation.3 + gte.4 = f16[1024] get-tuple-element(fusion.3), index=0 + gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1 + ROOT tuple.1 = (f16[], f16[1024]{0}, f16[], f16[1024]{0}) + tuple(gte.3, gte.4, gte.5, gte.6) + } +)").ValueOrDie(); + + EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla From aecc90e47c537704284410a6e92006e6c124c4fb Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Mon, 12 Oct 2020 18:55:09 -0700 Subject: [PATCH 6/6] [XLA/GPU] Formatting. --- tensorflow/compiler/xla/service/gpu/BUILD | 2 +- tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index d81182f6036..bb3e09a2e76 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1186,8 +1186,8 @@ cc_library( ":gpu_layout_assignment", ":gpu_sanitize_constant_names", ":gpu_scatter_expander", - ":horizontal_loop_fusion", ":horizontal_input_fusion", + ":horizontal_loop_fusion", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc index f0bd144c305..58ed9f18840 100644 --- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc @@ -130,7 +130,8 @@ StatusOr HorizontalInputFusionImpl::Run() { HloInstruction* fused = candidates[j]; if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && !FusionWouldBeTooLarge(*fusion_anchor, *fused)) { - VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString(); + VLOG(3) << "Fuse " << fused->ToString() << " into " + << fusion_anchor->ToString(); fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); changed = true; } else {