diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 002bf9ea901..eebc885f87b 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1153,8 +1153,8 @@ cc_library(
         ":gpu_layout_assignment",
         ":gpu_sanitize_constant_names",
         ":gpu_scatter_expander",
-        ":horizontal_fusion",
         ":horizontal_input_fusion",
+        ":horizontal_loop_fusion",
         ":instruction_fusion",
         ":ir_emission_utils",
         ":ir_emitter",
@@ -1730,9 +1730,9 @@ tf_cc_test(
 )
 
 cc_library(
-    name = "horizontal_fusion",
-    srcs = ["horizontal_fusion.cc"],
-    hdrs = ["horizontal_fusion.h"],
+    name = "horizontal_loop_fusion",
+    srcs = ["horizontal_loop_fusion.cc"],
+    hdrs = ["horizontal_loop_fusion.h"],
     deps = [
         ":gpu_fusible",
         "//tensorflow/compiler/xla:shape_util",
@@ -1747,11 +1747,11 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "horizontal_fusion_test",
-    srcs = ["horizontal_fusion_test.cc"],
+    name = "horizontal_loop_fusion_test",
+    srcs = ["horizontal_loop_fusion_test.cc"],
     deps = [
         ":fusion_merger",
-        ":horizontal_fusion",
+        ":horizontal_loop_fusion",
         ":instruction_fusion",
         ":multi_output_fusion",
         "//tensorflow/compiler/jit:xla_gpu_jit",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index cc4de2c1099..571d7656861 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -29,7 +29,7 @@ limitations under the License.
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -59,8 +59,8 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
@@ -302,7 +302,7 @@ Status GpuCompiler::OptimizeHloModule(
     TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());
 
     HloPassPipeline horizontal_fusion("horizontal_fusion");
-    horizontal_fusion.AddPass<GpuHorizontalFusion>();
+    horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
     horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
     horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                       /*only_fusion_computations=*/true);
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h
index 85313d03412..71302317ec2 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h
@@ -27,11 +27,11 @@ namespace gpu {
 
 // This optimization pass horizontally fuses kInput fusions to both reduce the
 // kernel launch overhead and increase parallelism degree. See
-// GpuHorizontalFusion for general description and motivation about horizontal
-// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals
-// with kInput fusions.
+// GpuHorizontalLoopFusion for general description and motivation about
+// horizontal fusion. GpuHorizontalLoopFusion deals with kLoop fusions while
+// this pass deals with kInput fusions.
 //
-// Following GpuHorizontalFusion, a simple yet effective heuristic is used
+// Following GpuHorizontalLoopFusion, a simple yet effective heuristic is used
 // to search the fusion candidates while avoiding creating cycles. That is,
 // we simply search for fusion candidates by looking for instructions whose
 // outputs are all consumed by the same instruction. This catches the typical
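Reviewer note on the heuristic above: the candidate search it describes can be sketched in a few lines. The helper below is hypothetical (FindSiblingFusionCandidates is not part of this change); the real logic lives in HorizontalLoopFusionImpl::FusionCandidates::Initialize in the file that follows, which additionally filters out unsupported and unprofitable fusions before grouping.

// Hypothetical sketch only -- not part of the patch. Collects the distinct
// fusion instructions whose outputs all feed one common consumer, which is
// the cycle-avoiding candidate set both horizontal-fusion passes start from.
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

std::vector<xla::HloInstruction*> FindSiblingFusionCandidates(
    const xla::HloInstruction* consumer) {
  absl::flat_hash_set<xla::HloInstruction*> seen;
  std::vector<xla::HloInstruction*> candidates;
  for (xla::HloInstruction* opnd : consumer->operands()) {
    // Look through get-tuple-element so multi-output fusions also qualify.
    xla::HloInstruction* pred = opnd->LatestNonGteAncestor();
    if (pred->opcode() == xla::HloOpcode::kFusion && seen.insert(pred).second) {
      candidates.push_back(pred);
    }
  }
  return candidates;
}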
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
similarity index 96%
rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc
rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
index d11d1659d51..9d1e0533a91 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
 
 #include <algorithm>
 
@@ -67,12 +67,12 @@ PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) {
   return first_output_type;
 }
 
-class HorizontalFusionImpl {
+class HorizontalLoopFusionImpl {
  public:
-  explicit HorizontalFusionImpl(HloComputation* computation)
+  explicit HorizontalLoopFusionImpl(HloComputation* computation)
       : computation_(computation) {}
 
-  ~HorizontalFusionImpl() {}
+  ~HorizontalLoopFusionImpl() {}
 
   StatusOr<bool> Run();
 
@@ -114,7 +114,7 @@ class HorizontalFusionImpl {
   };
 
   HloComputation* computation_;
-};  // HorizontalFusionImpl
+};  // HorizontalLoopFusionImpl
 
 bool IsFusionSupported(const HloInstruction& instr) {
   // Support only kLoop fusion now.
@@ -203,7 +203,7 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
   return true;
 }
 
-void HorizontalFusionImpl::FusionCandidates::Initialize(
+void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
     HloInstruction* consumer) {
   // First, find out all fusion instructions. We will filter out
   // unsupported/non-profitable cases below.
@@ -257,7 +257,7 @@ void HorizontalFusionImpl::FusionCandidates::Initialize(
 
 // Gets a next span of fusion instructions to be fused.
 absl::Span<HloInstruction*>
-HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
+HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
   if (pos_ >= fusion_instrs_.size()) {
     return absl::Span<HloInstruction*>();
   }
@@ -315,7 +315,7 @@ HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
   return absl::MakeSpan(fusion_instrs_).subspan(left, right - left);
 }
 
-Status HorizontalFusionImpl::CreateFusedComputation(
+Status HorizontalLoopFusionImpl::CreateFusedComputation(
     absl::Span<HloInstruction*> fused_fusion_instrs,
     std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
@@ -423,7 +423,7 @@ Status HorizontalFusionImpl::CreateFusedComputation(
   return Status::OK();
 }
 
-Status HorizontalFusionImpl::Fuse(
+Status HorizontalLoopFusionImpl::Fuse(
     absl::Span<HloInstruction*> fused_fusion_instrs) {
   // Fuse fused_fusion_instrs and replace them with the new fused computation.
   std::unique_ptr<HloComputation> uniq_computation;
@@ -465,7 +465,7 @@ Status HorizontalFusionImpl::Fuse(
   return Status::OK();
 }
 
-StatusOr<bool> HorizontalFusionImpl::Run() {
+StatusOr<bool> HorizontalLoopFusionImpl::Run() {
   bool changed = false;
   XLA_VLOG_LINES(3, computation_->ToString());
 
@@ -474,7 +474,7 @@ StatusOr<bool> HorizontalFusionImpl::Run() {
       computation_->MakeInstructionPostOrder();
   for (size_t i = 0; i < def_to_use_order.size(); ++i) {
     auto consumer = def_to_use_order[i];
-    HorizontalFusionImpl::FusionCandidates fusion_candidates(consumer);
+    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
     while (true) {
       auto fusions = fusion_candidates.GetNextSpanOfFusions();
       if (fusions.empty()) {
@@ -494,13 +494,13 @@ StatusOr<bool> HorizontalFusionImpl::Run() {
 
 }  // namespace
 
-StatusOr<bool> GpuHorizontalFusion::RunOnComputation(
+StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
     HloComputation* computation) {
-  HorizontalFusionImpl horizontal_fusion_impl(computation);
+  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
   return horizontal_fusion_impl.Run();
 }
 
-StatusOr<bool> GpuHorizontalFusion::Run(HloModule* module) {
+StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
   bool changed = false;
   VLOG(2) << "Run horizontal fusion.";
   for (auto* comp : module->MakeNonfusionComputations()) {
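Reviewer note: the renamed pass is driven like any other HloModulePass. Below is a minimal sketch of invoking it by hand, followed by HloDCE to sweep up the instructions the rewrite leaves dead, mirroring what the tests further down do; the wrapper name RunHorizontalLoopFusionWithCleanup is invented for illustration.

// Illustrative wrapper, assuming nothing beyond the pass interfaces shown in
// this diff. Returns whether horizontal loop fusion changed the module.
#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"

xla::StatusOr<bool> RunHorizontalLoopFusionWithCleanup(xla::HloModule* module) {
  xla::StatusOr<bool> fused = xla::gpu::GpuHorizontalLoopFusion().Run(module);
  if (!fused.ok()) {
    return fused.status();
  }
  if (fused.ValueOrDie()) {
    // The rewrite redirects users to the new fused computation; DCE then
    // removes the now-dead sibling fusions (the tests below do the same).
    xla::StatusOr<bool> dce = xla::HloDCE().Run(module);
    if (!dce.ok()) {
      return dce.status();
    }
  }
  return fused.ValueOrDie();
}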
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.h b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h
similarity index 91%
rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion.h
rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h
index 9a804949b1c..3824c5df352 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion.h
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
 
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -94,11 +94,13 @@ namespace gpu {
 // output dims of the concatenate will be used as the kernel launch dims.
 // Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
 // outputs of Mul and Add are row-major.
-class GpuHorizontalFusion : public HloModulePass {
+class GpuHorizontalLoopFusion : public HloModulePass {
  public:
-  GpuHorizontalFusion() {}
+  GpuHorizontalLoopFusion() {}
 
-  absl::string_view name() const override { return "gpu_horizontal_fusion"; }
+  absl::string_view name() const override {
+    return "gpu_horizontal_loop_fusion";
+  }
 
   StatusOr<bool> Run(HloModule* module) override;
 
@@ -109,4 +111,4 @@ class GpuHorizontalFusion : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
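Reviewer note: to make the header comment concrete, here is a schematic of the fused computation it describes, in the same // style the header itself uses. Names and shapes are invented; only the flatten/concatenate/slice/unflatten structure and the launch-dims remark come from the comment above.

// Schematic only (invented names/shapes):
//
// fused_computation {
//   mul = f32[1024]{0} ...                    // body of sibling fusion 1
//   add = f32[32,32]{1,0} ...                 // body of sibling fusion 2
//   reshape0 = f32[1024]{0} reshape(mul)      // flatten to 1-D
//   reshape1 = f32[1024]{0} reshape(add)
//   concat = f32[2048]{0} concatenate(reshape0, reshape1)
//       // kernel launch dims are taken from concat's output dims
//   slice0 = f32[1024]{0} slice(concat), slice={[0:1024]}
//   slice1 = f32[1024]{0} slice(concat), slice={[1024:2048]}
//   reshape2 = f32[1024]{0} reshape(slice0)   // bitcast when row-major
//   reshape3 = f32[32,32]{1,0} reshape(slice1)
//   ROOT tuple = (f32[1024]{0}, f32[32,32]{1,0}) tuple(reshape2, reshape3)
// }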
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
similarity index 92%
rename from tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc
rename to tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
index bad589964ff..5b8f4f3cc0d 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
 
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
@@ -37,9 +37,9 @@ namespace {
 
 namespace op = xla::testing::opcode_matchers;
 
-class HorizontalFusionTest : public HloTestBase {};
+class HorizontalLoopFusionTest : public HloTestBase {};
 
-TEST_F(HorizontalFusionTest, BasicTest) {
+TEST_F(HorizontalLoopFusionTest, BasicTest) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest
 
@@ -67,10 +67,9 @@ TEST_F(HorizontalFusionTest, BasicTest) {
 
   ROOT tuple.1 = (f16[1024]{0}, f16[123]{0}) tuple(fusion.1, fusion.2)
 }
-)")
-                    .ValueOrDie();
+)").ValueOrDie();
 
-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
   EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
 
   const HloInstruction* entry_root =
@@ -88,7 +87,7 @@
 }
 
 // Horizontal fusion should not be triggered as fusion will create cycles.
-TEST_F(HorizontalFusionTest, NegativeTestForCycle) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForCycle
 
@@ -119,13 +118,12 @@ TEST_F(HorizontalFusionTest, NegativeTestForCycle) {
   ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0}) tuple(fusion.1, fusion.2, add.2)
 }
-)")
-                    .ValueOrDie();
+)").ValueOrDie();
 
-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
 }
 
-TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForIncompatibleTypes
 
@@ -155,13 +153,12 @@ TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) {
   ROOT tuple.1 = (f16[1024]{0}, s32[123]{0}) tuple(fusion.1, fusion.2)
 }
-)")
-                    .ValueOrDie();
+)").ValueOrDie();
 
-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
 }
 
-TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
+TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MergeSharedFusionInstruction
 
@@ -183,14 +180,13 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
  mul.2.2 = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
  add.2 = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
  ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
-})")
-                    .ValueOrDie();
+})").ValueOrDie();
 
   HloPassPipeline fusion("fusion");
   fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
   fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
   EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
 
   VLOG(2) << "Dump after horizontal fusion:";
   VLOG(2) << module->ToString();
@@ -198,7 +194,7 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
   EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
 }
 
-TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) {
+TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
   HloComputation::Builder builder(TestName());
 
   std::vector<HloInstruction*> var_outs;
@@ -229,7 +225,7 @@ TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) {
 
   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
 }
 
-TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
+TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule HeterogeneousMultiOutputFusions
 
@@ -277,10 +273,9 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
   ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0}) tuple(gte.1, gte.2, gte.3, gte.4)
 }
-)")
-                    .ValueOrDie();
+)").ValueOrDie();
 
-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
   EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
 
   VLOG(2) << "Dump after horizontal fusion:";
@@ -289,7 +284,7 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
   EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
 }
 
-TEST_F(HorizontalFusionTest, RMSPropLike) {
+TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
   HloComputation::Builder builder(TestName());
 
   std::vector<HloInstruction*> all_outputs;
@@ -364,7 +359,7 @@ TEST_F(HorizontalFusionTest, RMSPropLike) {
   EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
 }
 
-TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
   auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForDynamicUpdateSlice
 
@@ -397,10 +392,9 @@ TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) {
   f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
   f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
   ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
-  })")
-                    .ValueOrDie();
+  })").ValueOrDie();
 
-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
 }
 
 }  // namespace