Rename horizontal_fusion to horizontal_loop_fusion.

Trent Lo 2020-09-16 17:58:22 -07:00
parent 7697628be7
commit a96d041ae6
6 changed files with 58 additions and 62 deletions

tensorflow/compiler/xla/service/gpu/BUILD

@@ -1153,8 +1153,8 @@ cc_library(
        ":gpu_layout_assignment",
        ":gpu_sanitize_constant_names",
        ":gpu_scatter_expander",
-       ":horizontal_fusion",
        ":horizontal_input_fusion",
+       ":horizontal_loop_fusion",
        ":instruction_fusion",
        ":ir_emission_utils",
        ":ir_emitter",
@@ -1730,9 +1730,9 @@ tf_cc_test(
)

cc_library(
-    name = "horizontal_fusion",
-    srcs = ["horizontal_fusion.cc"],
-    hdrs = ["horizontal_fusion.h"],
+    name = "horizontal_loop_fusion",
+    srcs = ["horizontal_loop_fusion.cc"],
+    hdrs = ["horizontal_loop_fusion.h"],
    deps = [
        ":gpu_fusible",
        "//tensorflow/compiler/xla:shape_util",
@@ -1747,11 +1747,11 @@ cc_library(
)

tf_cc_test(
-    name = "horizontal_fusion_test",
-    srcs = ["horizontal_fusion_test.cc"],
+    name = "horizontal_loop_fusion_test",
+    srcs = ["horizontal_loop_fusion_test.cc"],
    deps = [
        ":fusion_merger",
-        ":horizontal_fusion",
+        ":horizontal_loop_fusion",
        ":instruction_fusion",
        ":multi_output_fusion",
        "//tensorflow/compiler/jit:xla_gpu_jit",

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -29,7 +29,7 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
-#include "mlir/IR/Module.h" // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/InitAllDialects.h"  // from @llvm-project
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -59,8 +59,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
@@ -302,7 +302,7 @@ Status GpuCompiler::OptimizeHloModule(
  TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status());

  HloPassPipeline horizontal_fusion("horizontal_fusion");
-  horizontal_fusion.AddPass<GpuHorizontalFusion>();
+  horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
  horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
  horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                    /*only_fusion_computations=*/true);
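After the rename, this stage still runs loop fusion before input fusion, followed by CSE on the fused computations. The same passes can also be exercised in isolation; a minimal sketch mirroring the test usage later in this commit, where the pipeline name "horizontal_fusion_only" is illustrative and `module` is assumed to hold a parsed, verified HloModule:

  // Build a pipeline containing only the horizontal-fusion passes.
  // Run() returns StatusOr<bool>: true iff the module changed.
  HloPassPipeline pipeline("horizontal_fusion_only");
  pipeline.AddPass<GpuHorizontalLoopFusion>();
  pipeline.AddPass<GpuHorizontalInputFusion>();
  bool changed = pipeline.Run(module.get()).ValueOrDie();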

tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h

@@ -27,11 +27,11 @@ namespace gpu {

// This optimization pass horizontally fuses kInput fusions to both reduce the
// kernel launch overhead and increase parallelism degree. See
-// GpuHorizontalFusion for general description and motivation about horizontal
-// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals
-// with kInput fusions.
+// GpuHorizontalLoopFusion for general description and motivation about
+// horizontal fusion. GpuHorizontalLoopFusion deals with kLoop fusions while
+// this pass deals with kInput fusions.
//
-// Following GpuHorizontalFusion, a simple yet effective heuristic is used
+// Following GpuHorizontalLoopFusion, a simple yet effective heuristic is used
// to search the fusion candidates while avoiding creating cycles. That is,
// we simply search for fusion candidates by looking for instructions whose
// outputs are all consumed by the same instruction. This catches the typical
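The heuristic described in that comment can be sketched roughly as below. This is a hypothetical illustration, not code from this commit; FindSiblingFusionCandidates is an invented name rather than the real helper:

  #include <vector>

  #include "absl/algorithm/container.h"
  #include "tensorflow/compiler/xla/service/hlo_instruction.h"

  // Gather sibling fusions whose outputs are all consumed by the same
  // `consumer`, so fusing them horizontally cannot create a cycle. (The real
  // pass additionally filters unsupported/non-profitable fusions; duplicate
  // operands are ignored here for brevity.)
  std::vector<xla::HloInstruction*> FindSiblingFusionCandidates(
      xla::HloInstruction* consumer) {
    std::vector<xla::HloInstruction*> candidates;
    for (xla::HloInstruction* operand : consumer->operands()) {
      if (operand->opcode() == xla::HloOpcode::kFusion &&
          absl::c_all_of(operand->users(),
                         [&](const xla::HloInstruction* user) {
                           return user == consumer;
                         })) {
        candidates.push_back(operand);
      }
    }
    return candidates;
  }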

tensorflow/compiler/xla/service/gpu/horizontal_fusion.cc → tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>
@@ -67,12 +67,12 @@ PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) {
  return first_output_type;
}

-class HorizontalFusionImpl {
+class HorizontalLoopFusionImpl {
 public:
-  explicit HorizontalFusionImpl(HloComputation* computation)
+  explicit HorizontalLoopFusionImpl(HloComputation* computation)
      : computation_(computation) {}

-  ~HorizontalFusionImpl() {}
+  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();
@@ -114,7 +114,7 @@ class HorizontalFusionImpl {
  };

  HloComputation* computation_;
-};  // HorizontalFusionImpl
+};  // HorizontalLoopFusionImpl

bool IsFusionSupported(const HloInstruction& instr) {
  // Support only kLoop fusion now.
@@ -203,7 +203,7 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
  return true;
}

-void HorizontalFusionImpl::FusionCandidates::Initialize(
+void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find out all fusion instructions. We will filter out
  // unsupported/non-profitable cases below.
@@ -257,7 +257,7 @@ void HorizontalFusionImpl::FusionCandidates::Initialize(
// Gets a next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
-HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
+HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusion_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }
@@ -315,7 +315,7 @@ HorizontalFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  return absl::MakeSpan(fusion_instrs_).subspan(left, right - left);
}

-Status HorizontalFusionImpl::CreateFusedComputation(
+Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
@@ -423,7 +423,7 @@ Status HorizontalFusionImpl::CreateFusedComputation(
  return Status::OK();
}

-Status HorizontalFusionImpl::Fuse(
+Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
@@ -465,7 +465,7 @@ Status HorizontalFusionImpl::Fuse(
  return Status::OK();
}

-StatusOr<bool> HorizontalFusionImpl::Run() {
+StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());
@@ -474,7 +474,7 @@ StatusOr<bool> HorizontalFusionImpl::Run() {
      computation_->MakeInstructionPostOrder();
  for (size_t i = 0; i < def_to_use_order.size(); ++i) {
    auto consumer = def_to_use_order[i];
-    HorizontalFusionImpl::FusionCandidates fusion_candidates(consumer);
+    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusions = fusion_candidates.GetNextSpanOfFusions();
      if (fusions.empty()) {
@@ -494,13 +494,13 @@ StatusOr<bool> HorizontalFusionImpl::Run() {
}  // namespace

-StatusOr<bool> GpuHorizontalFusion::RunOnComputation(
+StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
-  HorizontalFusionImpl horizontal_fusion_impl(computation);
+  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

-StatusOr<bool> GpuHorizontalFusion::Run(HloModule* module) {
+StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";
  for (auto* comp : module->MakeNonfusionComputations()) {

tensorflow/compiler/xla/service/gpu/horizontal_fusion.h → tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h

@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -94,11 +94,13 @@ namespace gpu {
// output dims of the concatenate will be used as the kernel launch dims.
// Instruction bitcasts can be used for Reshape2 and Reshape3 as long as the
// outputs of Mul and Add are row-major.
-class GpuHorizontalFusion : public HloModulePass {
+class GpuHorizontalLoopFusion : public HloModulePass {
 public:
-  GpuHorizontalFusion() {}
+  GpuHorizontalLoopFusion() {}

-  absl::string_view name() const override { return "gpu_horizontal_fusion"; }
+  absl::string_view name() const override {
+    return "gpu_horizontal_loop_fusion";
+  }

  StatusOr<bool> Run(HloModule* module) override;
@@ -109,4 +111,4 @@ class GpuHorizontalFusion : public HloModulePass {
}  // namespace gpu
}  // namespace xla

-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_FUSION_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_
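For reference, the renamed pass can be invoked directly; a minimal usage sketch mirroring the tests in the next file, where `module` is assumed to come from ParseAndReturnVerifiedModule:

  // Run the pass on a module; Run() returns StatusOr<bool> indicating
  // whether any horizontal fusion took place. ValueOrDie() suits tests,
  // where a failed status should abort.
  bool changed = GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie();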

tensorflow/compiler/xla/service/gpu/horizontal_fusion_test.cc → tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion_test.cc

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

-#include "tensorflow/compiler/xla/service/gpu/horizontal_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
@@ -37,9 +37,9 @@ namespace {

namespace op = xla::testing::opcode_matchers;

-class HorizontalFusionTest : public HloTestBase {};
+class HorizontalLoopFusionTest : public HloTestBase {};

-TEST_F(HorizontalFusionTest, BasicTest) {
+TEST_F(HorizontalLoopFusionTest, BasicTest) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest
@@ -67,10 +67,9 @@ TEST_F(HorizontalFusionTest, BasicTest) {
 ROOT tuple.1 = (f16[1024]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2)
}
-)")
-                    .ValueOrDie();
+)").ValueOrDie();

-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
@@ -88,7 +87,7 @@ TEST_F(HorizontalFusionTest, BasicTest) {
}

// Horizontal fusion should not be triggered as fusion will create cycles.
-TEST_F(HorizontalFusionTest, NegativeTestForCycle) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForCycle
@@ -119,13 +118,12 @@ TEST_F(HorizontalFusionTest, NegativeTestForCycle) {
 ROOT tuple.1 = (f16[123]{0}, f16[123]{0}, f16[123]{0})
       tuple(fusion.1, fusion.2, add.2)
}
-)")
-                    .ValueOrDie();
+)").ValueOrDie();

-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

-TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForIncompatibleTypes
@@ -155,13 +153,12 @@ TEST_F(HorizontalFusionTest, NegativeTestForIncompatibleTypes) {
 ROOT tuple.1 = (f16[1024]{0}, s32[123]{0})
       tuple(fusion.1, fusion.2)
}
-)")
-                    .ValueOrDie();
+)").ValueOrDie();

-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

-TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
+TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MergeSharedFusionInstruction
@@ -183,14 +180,13 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
 mul.2.2 = f32[321,5]{1,0} multiply(param.2.3, broadcast.2)
 add.2 = f32[321,5]{1,0} add(mul.2.1, mul.2.2)
 ROOT tuple = (f32[4,1024]{1,0}, f32[321,5]{1,0}) tuple(add.1, add.2)
-})")
-                    .ValueOrDie();
+})").ValueOrDie();

  HloPassPipeline fusion("fusion");
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/false);
  fusion.AddPass<xla::gpu::GpuInstructionFusion>(/*may_duplicate=*/true);
  EXPECT_TRUE(fusion.Run(module.get()).ValueOrDie());
-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
  VLOG(2) << module->ToString();
@@ -198,7 +194,7 @@ TEST_F(HorizontalFusionTest, HorizontalFusionAfterVerticalFusion) {
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}

-TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) {
+TEST_F(HorizontalLoopFusionTest, GradientDescentOptimizerLike) {
  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> var_outs;
@@ -229,7 +225,7 @@ TEST_F(HorizontalFusionTest, GradientDescentOptimizerLike) {
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{0, 0}));
}

-TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
+TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule HeterogeneousMultiOutputFusions
@@ -277,10 +273,9 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
 ROOT tuple.1 = (f16[1024]{0}, f16[1024]{0}, f16[123]{0}, f16[123]{0})
       tuple(gte.1, gte.2, gte.3, gte.4)
}
-)")
-                    .ValueOrDie();
+)").ValueOrDie();

-  EXPECT_TRUE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
  EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

  VLOG(2) << "Dump after horizontal fusion:";
@@ -289,7 +284,7 @@ TEST_F(HorizontalFusionTest, FusingDifferentOutputs) {
  EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}

-TEST_F(HorizontalFusionTest, RMSPropLike) {
+TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> all_outputs;
@@ -364,7 +359,7 @@ TEST_F(HorizontalFusionTest, RMSPropLike) {
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}

-TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) {
+TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule NegativeTestForDynamicUpdateSlice
@@ -397,10 +392,9 @@ TEST_F(HorizontalFusionTest, NegativeTestForDynamicUpdateSlice) {
 f1 = f16[5,9,10] fusion(p.00, p.10, p.20), kind=kLoop, calls=fusion.1
 f2 = f16[5,9,10] fusion(p.01, p.11, p.21), kind=kLoop, calls=fusion.2
 ROOT tuple = (f16[5,9,10],f16[5,9,10]) tuple(f1, f2)
-})")
-                    .ValueOrDie();
+})").ValueOrDie();

-  EXPECT_FALSE(GpuHorizontalFusion().Run(module.get()).ValueOrDie());
+  EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

}  // namespace
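Assuming the BUILD rules above live at tensorflow/compiler/xla/service/gpu/BUILD, the renamed test target can be exercised with:

  bazel test //tensorflow/compiler/xla/service/gpu:horizontal_loop_fusion_test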