[XLA:GPU] Mark bitcasts as eligible for fusion.
Currently this never happens because we only turn rehaspes into bitcasts after layout assignment. This changes when layout assignment runs before fusion. Once layouts are available the pipeline turns reshapes into bitcasts, which would be left unfused without this change. PiperOrigin-RevId: 187999864
This commit is contained in:
parent
834093de42
commit
bec6e47cf9
tensorflow/compiler/xla
@ -1722,6 +1722,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
|
||||
return ir_builder_->CreateLoad(ret_value_addr);
|
||||
};
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kReshape:
|
||||
CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
|
||||
ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
|
||||
|
@ -397,6 +397,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -26,6 +26,7 @@ namespace {
|
||||
|
||||
bool IsFusile(const HloInstruction& hlo) {
|
||||
return (hlo.IsElementwise() && hlo.operand_count() > 0) ||
|
||||
hlo.opcode() == HloOpcode::kBitcast ||
|
||||
hlo.opcode() == HloOpcode::kBroadcast ||
|
||||
hlo.opcode() == HloOpcode::kConcatenate ||
|
||||
hlo.opcode() == HloOpcode::kDynamicSlice ||
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
@ -163,5 +164,49 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
|
||||
EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
|
||||
auto module = tools::Parse(R"(
|
||||
HloModule test_module
|
||||
|
||||
ENTRY BroadcastIntoAdd {
|
||||
p0 = f32[4,1,1]{2,1,0} parameter(0)
|
||||
p1 = f32[4,1]{1,0} parameter(1)
|
||||
bitcast = f32[4,1]{1,0} bitcast(p0)
|
||||
ROOT add = f32[4,1] add(bitcast, p1)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Fusion());
|
||||
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Add(op::Bitcast(op::Parameter()), op::Parameter()));
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, AddIntoBitcast) {
|
||||
auto module = tools::Parse(R"(
|
||||
HloModule test_module
|
||||
|
||||
ENTRY BroadcastIntoAdd {
|
||||
p0 = f32[4,1,1]{2,1,0} parameter(0)
|
||||
p1 = f32[4,1]{1,0} parameter(1)
|
||||
add = f32[4,1] add(p0, p1)
|
||||
ROOT bitcast = f32[4,1,1] bitcast(add)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Fusion());
|
||||
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Bitcast(op::Add(op::Parameter(), op::Parameter())));
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/filecheck.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
@ -49,11 +50,11 @@ void LLVMIRGenTestBase::CompileAndVerifyIr(
|
||||
std::unique_ptr<HloModule> hlo_module, const string& pattern,
|
||||
bool match_optimized_ir) {
|
||||
SetIrHook(match_optimized_ir);
|
||||
ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok());
|
||||
TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status());
|
||||
ResetIrHook();
|
||||
|
||||
StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
|
||||
ASSERT_TRUE(filecheck_result.ok());
|
||||
TF_ASSERT_OK(filecheck_result.status());
|
||||
EXPECT_TRUE(filecheck_result.ValueOrDie());
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user