[XLA:GPU] Mark bitcasts as eligible for fusion.

Currently this never happens because reshapes are only turned into bitcasts after
layout assignment. That changes once layout assignment runs before fusion: with
layouts available, the pipeline rewrites layout-preserving reshapes into bitcasts,
and those bitcasts would be left unfused without this change.

PiperOrigin-RevId: 187999864
Benjamin Kramer, 2018-03-06 03:31:45 -08:00 (committed by TensorFlower Gardener)
parent 834093de42
commit bec6e47cf9
5 changed files with 51 additions and 2 deletions
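
Why fusing a bitcast is free, as a standalone illustration (this is not XLA code; the buffer and index helpers below are invented for the example): once layouts are fixed, a reshape whose input and output enumerate elements in the same memory order, such as f32[4,1,1] -> f32[4,1] in row-major layout, is a pure reinterpretation of the same buffer. Fusing it costs nothing, while leaving it unfused forces a separate trivial kernel.

// Standalone sketch (not XLA code): a row-major f32[4,1,1] buffer and a
// row-major f32[4,1] buffer have identical element order, so "reshaping"
// between them is a bitcast -- a fused consumer can read the producer's
// buffer directly under the reinterpreted shape.
#include <array>
#include <cassert>
#include <cstdio>

int main() {
  // One linear buffer of 4 floats, viewed as f32[4,1,1] and as f32[4,1].
  std::array<float, 4> buffer = {1.f, 2.f, 3.f, 4.f};

  // Row-major linear index for shape [4,1,1].
  auto index_411 = [](int i, int j, int k) { return (i * 1 + j) * 1 + k; };
  // Row-major linear index for shape [4,1].
  auto index_41 = [](int i, int j) { return i * 1 + j; };

  // Every element maps to the same storage location under both shapes,
  // so the reshape/bitcast moves no data.
  for (int i = 0; i < 4; ++i) {
    assert(index_411(i, 0, 0) == index_41(i, 0));
  }

  // A fused consumer (here: add a constant) indexes the original buffer
  // directly; no intermediate [4,1] tensor is ever materialized.
  for (int i = 0; i < 4; ++i) {
    std::printf("%f\n", buffer[index_41(i, 0)] + 1.0f);
  }
  return 0;
}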


@@ -1722,6 +1722,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
         SetToFirstInsertPoint(if_data.after_block, ir_builder_);
         return ir_builder_->CreateLoad(ret_value_addr);
       };
+    case HloOpcode::kBitcast:
     case HloOpcode::kReshape:
       CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
                ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
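
The hunk above routes kBitcast through the same element-generator path as kReshape: both only require that input and output hold the same number of elements (the CHECK_EQ), after which the output element can be read from the operand via the shared row-major linear index. A rough standalone sketch of that index remapping, with made-up types rather than XLA's IrArray/ElementGenerator machinery:

// Standalone sketch of a reshape/bitcast "element generator": the value at a
// target multi-index is the operand's value at the multi-index that shares
// the same row-major linear position. Types and names are illustrative only.
#include <cassert>
#include <cstdio>
#include <vector>

using Index = std::vector<long>;

// Row-major multi-index -> linear index.
long LinearIndex(const Index& index, const std::vector<long>& dims) {
  long linear = 0;
  for (size_t d = 0; d < dims.size(); ++d) linear = linear * dims[d] + index[d];
  return linear;
}

// Row-major linear index -> multi-index.
Index DelinearizeIndex(long linear, const std::vector<long>& dims) {
  Index index(dims.size());
  for (size_t d = dims.size(); d-- > 0;) {
    index[d] = linear % dims[d];
    linear /= dims[d];
  }
  return index;
}

int main() {
  std::vector<long> operand_dims = {4, 1, 1};  // f32[4,1,1]
  std::vector<long> target_dims = {4, 1};      // f32[4,1]
  std::vector<float> operand = {1.f, 2.f, 3.f, 4.f};

  // Generator for target element [i,j]: read the operand element that sits at
  // the same linear offset. For a true bitcast this is the identity mapping.
  auto generate = [&](const Index& target_index) {
    long linear = LinearIndex(target_index, target_dims);
    Index operand_index = DelinearizeIndex(linear, operand_dims);
    return operand[LinearIndex(operand_index, operand_dims)];
  };

  assert(generate({2, 0}) == 3.f);
  std::printf("%f\n", generate({2, 0}));
  return 0;
}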


@@ -397,6 +397,7 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:hlo_matchers",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/tools/parser:hlo_parser",
     ],
 )


@@ -26,6 +26,7 @@ namespace {
 bool IsFusile(const HloInstruction& hlo) {
   return (hlo.IsElementwise() && hlo.operand_count() > 0) ||
+         hlo.opcode() == HloOpcode::kBitcast ||
          hlo.opcode() == HloOpcode::kBroadcast ||
          hlo.opcode() == HloOpcode::kConcatenate ||
          hlo.opcode() == HloOpcode::kDynamicSlice ||


@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

 namespace op = xla::testing::opcode_matchers;

@@ -163,5 +164,49 @@ TEST_F(InstructionFusionTest, GetTupleElementFused) {
   EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode());
 }

+TEST_F(InstructionFusionTest, BitcastIntoAdd) {
+  auto module = tools::Parse(R"(
+  HloModule test_module
+
+  ENTRY BroadcastIntoAdd {
+    p0 = f32[4,1,1]{2,1,0} parameter(0)
+    p1 = f32[4,1]{1,0} parameter(1)
+    bitcast = f32[4,1]{1,0} bitcast(p0)
+    ROOT add = f32[4,1] add(bitcast, p1)
+  })")
+                    .ValueOrDie();
+
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+                  .Run(module.get())
+                  .ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Add(op::Bitcast(op::Parameter()), op::Parameter()));
+}
+
+TEST_F(InstructionFusionTest, AddIntoBitcast) {
+  auto module = tools::Parse(R"(
+  HloModule test_module
+
+  ENTRY BroadcastIntoAdd {
+    p0 = f32[4,1,1]{2,1,0} parameter(0)
+    p1 = f32[4,1]{1,0} parameter(1)
+    add = f32[4,1] add(p0, p1)
+    ROOT bitcast = f32[4,1,1] bitcast(add)
+  })")
+                    .ValueOrDie();
+
+  EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+                  .Run(module.get())
+                  .ValueOrDie());
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  EXPECT_THAT(root, op::Fusion());
+  EXPECT_THAT(root->fused_expression_root(),
+              op::Bitcast(op::Add(op::Parameter(), op::Parameter())));
+}
+
 }  // namespace gpu
 }  // namespace xla


@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/tests/filecheck.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"

 namespace xla {

@@ -49,11 +50,11 @@ void LLVMIRGenTestBase::CompileAndVerifyIr(
     std::unique_ptr<HloModule> hlo_module, const string& pattern,
     bool match_optimized_ir) {
   SetIrHook(match_optimized_ir);
-  ASSERT_TRUE(CompileToExecutable(std::move(hlo_module)).ok());
+  TF_ASSERT_OK(CompileToExecutable(std::move(hlo_module)).status());
   ResetIrHook();

   StatusOr<bool> filecheck_result = RunFileCheck(ir_, pattern);
-  ASSERT_TRUE(filecheck_result.ok());
+  TF_ASSERT_OK(filecheck_result.status());
   EXPECT_TRUE(filecheck_result.ValueOrDie());
 }
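
The last two replacements swap ASSERT_TRUE(x.ok()) for TF_ASSERT_OK(x.status()), using the macro from the newly included status_test_util.h, so a failing compile or FileCheck run reports the underlying error message instead of a bare "false". A simplified standalone imitation of the idea (this is not the real TensorFlow macro, and the error string is invented for the example):

// Standalone illustration (not the real TF_ASSERT_OK): checking a Status
// through a helper that prints the error message gives a more useful
// failure report than asserting on a bare bool.
#include <cstdio>
#include <cstdlib>
#include <string>

struct Status {
  bool ok_ = true;
  std::string message;
  bool ok() const { return ok_; }
};

// ASSERT_TRUE-style check: on failure we only learn "expression was false".
#define ASSERT_TRUE_SKETCH(expr)                             \
  do {                                                       \
    if (!(expr)) {                                           \
      std::fprintf(stderr, "assertion failed: %s\n", #expr); \
      std::exit(1);                                          \
    }                                                        \
  } while (0)

// TF_ASSERT_OK-style check: on failure we also see the status message.
#define ASSERT_OK_SKETCH(status_expr)                                  \
  do {                                                                 \
    Status _s = (status_expr);                                         \
    if (!_s.ok()) {                                                    \
      std::fprintf(stderr, "status not OK: %s\n", _s.message.c_str()); \
      std::exit(1);                                                    \
    }                                                                  \
  } while (0)

Status Compile(bool succeed) {
  if (succeed) return Status{};
  return Status{false, "illustrative error: IR verification failed"};
}

int main() {
  ASSERT_TRUE_SKETCH(Compile(true).ok());  // passes silently
  ASSERT_OK_SKETCH(Compile(true));         // passes silently
  // On failure, only the second form would print the embedded error string,
  // e.g. "status not OK: illustrative error: IR verification failed".
  return 0;
}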