STT-tensorflow/tensorflow/compiler/xla/service/instruction_fusion_test.cc
Yunxing Dai 2178657934 rng loop fusion
PiperOrigin-RevId: 314951531
Change-Id: If2842a456471eaabd1fbab8c792fd84d95b4f792
2020-06-05 10:34:16 -07:00

476 lines
18 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace op = xla::testing::opcode_matchers;
using InstructionFusionTest = HloTestBase;
// Test-only wrapper around InstructionFusion. Construction binds the pass to
// `module` and to that module's entry computation; the two overrides re-export
// the protected Fuse and FuseIntoMultiOutput methods publicly so that tests can
// fuse one specific producer/consumer pair directly.
class InstructionFusionForTesting : public InstructionFusion {
 public:
  explicit InstructionFusionForTesting(HloModule* module)
      : InstructionFusion(InstructionFusion::IsExpensive) {
    // Seed the pass state that the fuse helpers below operate on.
    this->module_ = module;
    this->computation_ = module->entry_computation();
  }

  // Public forwarder to InstructionFusion::Fuse; returns whatever the base
  // implementation returns.
  HloInstruction* Fuse(HloInstruction* producer,
                       HloInstruction* consumer) override {
    HloInstruction* const fused = InstructionFusion::Fuse(producer, consumer);
    return fused;
  }

  // Public forwarder to InstructionFusion::FuseIntoMultiOutput; returns
  // whatever the base implementation returns.
  HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
                                      HloInstruction* consumer) override {
    HloInstruction* const fused =
        InstructionFusion::FuseIntoMultiOutput(producer, consumer);
    return fused;
  }
};
// Fuses the producer `add` into its consumer `sub` and checks that the result
// is a fusion instruction whose fused root is subtract(add, parameter).
TEST_F(InstructionFusionTest, FuseInstructions) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
add = f32[4,3]{1,0} add(p0, p0)
ROOT sub = f32[4,3]{1,0} subtract(add, p0)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // The root is the consumer; its first operand is the producer to fuse.
  HloInstruction* consumer = module->entry_computation()->root_instruction();
  HloInstruction* producer = consumer->mutable_operand(0);

  InstructionFusionForTesting fusion_pass(module.get());
  HloInstruction* fused = fusion_pass.Fuse(producer, consumer);

  ASSERT_THAT(fused, op::Fusion()) << module->ToString();
  EXPECT_THAT(fused->fused_expression_root(),
              op::Subtract(op::Add(), op::Parameter()))
      << module->ToString();
}
// Fuses `abs` into a consumer that is already a kLoop fusion and checks that
// the fused root becomes add(abs, abs).
TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
  const char* const kHloText = R"(
HloModule test_module
fused_computation {
p1 = f32[4,3] parameter(0)
add = f32[4,3] add(p1, p1)
}
ENTRY entry_computation {
p0 = f32[4,3] parameter(0)
abs = f32[4,3] abs(p0)
ROOT fusion = f32[4,3] fusion(abs), kind=kLoop, calls=fused_computation
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // The pre-existing fusion is the entry root; `abs` is its only operand.
  HloInstruction* existing_fusion =
      module->entry_computation()->root_instruction();
  HloInstruction* producer = existing_fusion->mutable_operand(0);

  InstructionFusionForTesting fusion_pass(module.get());
  HloInstruction* fused = fusion_pass.Fuse(producer, existing_fusion);

  ASSERT_THAT(fused, op::Fusion()) << module->ToString();
  EXPECT_THAT(fused->fused_expression_root(), op::Add(op::Abs(), op::Abs()))
      << module->ToString();
}
// Multi-output fuses the producer `abs` with its consumer `tanh` and checks
// that the fused root is the tuple (tanh, abs).
TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
abs = f32[4,3]{1,0} abs(p0)
tanh = f32[4,3]{1,0} tanh(abs)
ROOT add = f32[4,3]{1,0} add(abs, tanh)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // Operand 0 of the root add is `abs`, operand 1 is `tanh`.
  HloInstruction* add_root = module->entry_computation()->root_instruction();
  HloInstruction* producer = add_root->mutable_operand(0);
  HloInstruction* consumer = add_root->mutable_operand(1);

  InstructionFusionForTesting fusion_pass(module.get());
  HloInstruction* fused = fusion_pass.FuseIntoMultiOutput(producer, consumer);

  ASSERT_THAT(fused, op::Fusion()) << module->ToString();
  EXPECT_THAT(fused->fused_expression_root(),
              op::Tuple(op::Tanh(), op::Abs()))
      << module->ToString();
}
// Builds add(param0, param1) with two users, a send and the root abs, and
// checks that running the pass with may_duplicate=true reports no change.
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) {
  HloComputation::Builder builder(TestName());
  const auto shape = ShapeUtil::MakeShape(F32, {16, 16});

  HloInstruction* lhs =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
  HloInstruction* rhs =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));

  // First user of `sum`: a send / send-done pair.
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send =
      builder.AddInstruction(HloInstruction::CreateSend(sum, token, 0));
  builder.AddInstruction(HloInstruction::CreateSendDone(send));

  // Second user of `sum`: the abs that ends up as the computation root.
  HloInstruction* abs_of_sum = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, sum));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(abs_of_sum, entry->root_instruction());

  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_FALSE(changed) << module->ToString();
}
// Returns how many instructions with opcode `op` appear across every
// computation yielded by `module.computations()`.
static int Count(const HloModule& module, HloOpcode op) {
  int matches = 0;
  for (const auto* comp : module.computations()) {
    for (const auto* instr : comp->instructions()) {
      matches += (instr->opcode() == op) ? 1 : 0;
    }
  }
  return matches;
}
// `add` feeds both operands of the root subtract. The pass must fuse the two
// into a single fusion while leaving exactly one add in the module.
TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
add = f32[4,3]{1,0} add(p0, p0)
ROOT root = f32[4,3]{1,0} subtract(add, add)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // The add and the subtraction are expected to be fused together.
  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed) << module->ToString();

  const int fusion_count = Count(*module, HloOpcode::kFusion);
  EXPECT_EQ(fusion_count, 1) << module->ToString();

  // A single add means the producer was not duplicated.
  const int add_count = Count(*module, HloOpcode::kAdd);
  EXPECT_EQ(add_count, 1) << module->ToString();
}
// Runs the fusion pass (may_duplicate=true) over four independent HLO modules
// and, for each, checks both which instructions ended up fused and that the
// total number of add instructions in the module is unchanged, i.e. that no
// add was duplicated. `module` and `root` are reused across the scenarios.
TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
// Scenario 1.
// Make sure we do not duplicate the add, as we cannot fuse through the rng.
//
// (p0, p1) -> add -------------------------> sub
//                 \-> abs1 -> rng -> abs2 -/
//
// Not drawn: the HLO below also contains abs3 = abs(rng), a second user of
// rng that nothing else consumes.
auto module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
add = f32[] add(p0, p1)
abs1 = f32[] abs(add)
rng = f32[] rng(p1, abs1), distribution=rng_uniform
abs2 = f32[] abs(rng)
abs3 = f32[] abs(rng)
ROOT root = f32[] subtract(abs2, add)
})")
.ValueOrDie();
// We expect abs2 to be fused into root.
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie())
<< module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
// Inside the fusion, both the rng result and the add arrive as parameters.
EXPECT_THAT(root->fused_expression_root(),
op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
<< module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
// Scenario 2.
// Use a log node with a second consumer to break the fusion.
//
// (p0, p1) -> add -------------------------> sub
//                 \-> abs1 -> log -> abs2 -/
//                             \-> send -> send-done
module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
add = f32[4,3]{1,0} add(p0, p1)
abs1 = f32[4,3]{1,0} abs(add)
log = f32[4,3]{1,0} log(abs1)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=1
send-done = token[] send-done(send), channel_id=1
abs2 = f32[4,3]{1,0} abs(log)
ROOT root = f32[4,3]{1,0} subtract(abs2, add)
})")
.ValueOrDie();
// We expect abs2 to be fused into root and abs1 to be fused into log.
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie())
<< module->ToString();
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 2) << module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
// Scenario 3.
// Make sure we still fuse ops where one operand in the chain to the producer
// can't be fused.
//
// (p0, p1) ---> add1 -----------> sub
//          \         \-> add2 -/
//           \-> log -/
//                   \-> send -> send-done
module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
add1 = f32[4,3]{1,0} add(p0, p1)
log = f32[4,3]{1,0} log(p0)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=1
send-done = token[] send-done(send), channel_id=1
add2 = f32[4,3]{1,0} add(log, add1)
ROOT root = f32[4,3]{1,0} subtract(add1, add2)
})")
.ValueOrDie();
// Expect the add1 and add2 to be fused into root.
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie())
<< module->ToString();
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
// Scenario 4.
// A variant of the above that allows the algorithm to put add2 into the set
// of unfusible ops to short-circuit the decision whether add1 should be fused
// into sub2.
//
//             /---------------\
// (p0, p1) ---> add1 ---> add2 ------> sub2
//                             \------> sub1
//                              log -/
//                                  \-> send
module = ParseAndReturnVerifiedModule(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
add1 = f32[4,3]{1,0} add(p0, p1)
add2 = f32[4,3]{1,0} add(add1, p1)
log = f32[4,3]{1,0} log(add2)
token0 = token[] after-all()
send = f32[4,3]{1,0} send(log, token0), channel_id=1
send-done = token[] send-done(send), channel_id=1
sub1 = f32[4,3]{1,0} subtract(log, add2)
sub2 = f32[4,3]{1,0} subtract(add2, add1)
ROOT root = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub1, sub2)
})")
.ValueOrDie();
// Expect sub1 and sub2 to be fused into root.
EXPECT_TRUE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie())
<< module->ToString();
root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
// Every subtract operand is a fusion parameter, so neither add was pulled
// into the fused computation.
EXPECT_THAT(root->fused_expression_root(),
op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
op::Subtract(op::Parameter(), op::Parameter())))
<< module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
}
// Builds floor(param0) with two users, a send and the root abs, and checks
// that running the pass with may_duplicate=true reports a change.
TEST_F(InstructionFusionTest, AllowUnaryDuplication) {
  HloComputation::Builder builder(TestName());
  const auto shape = ShapeUtil::MakeShape(F32, {16, 16});

  HloInstruction* input =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "0"));
  HloInstruction* floor_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kFloor, input));

  // First user of the floor: a send / send-done pair.
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send =
      builder.AddInstruction(HloInstruction::CreateSend(floor_op, token, 0));
  builder.AddInstruction(HloInstruction::CreateSendDone(send));

  // Second user of the floor: the abs that becomes the computation root.
  HloInstruction* abs_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, floor_op));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(abs_op, entry->root_instruction());

  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
}
// Builds add(broadcast(param0), param1) with two users, a send and the root
// abs, and checks that running the pass with may_duplicate=true reports a
// change.
TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
  const auto shape = ShapeUtil::MakeShape(F32, {16, 16});
  const auto small_shape = ShapeUtil::MakeShape(F32, {16});
  HloComputation::Builder builder(TestName());

  HloInstruction* small_input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, small_shape, "0"));
  HloInstruction* full_input =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "1"));

  // One add operand is a broadcast of the rank-1 parameter.
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(shape, small_input, {0}));
  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, broadcast, full_input));

  // First user of `sum`: a send / send-done pair.
  HloInstruction* token =
      builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send =
      builder.AddInstruction(HloInstruction::CreateSend(sum, token, 0));
  builder.AddInstruction(HloInstruction::CreateSendDone(send));

  // Second user of `sum`: the abs that becomes the computation root.
  HloInstruction* abs_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kAbs, sum));

  auto module = CreateNewVerifiedModule();
  auto* entry = module->AddEntryComputation(builder.Build());
  EXPECT_EQ(abs_op, entry->root_instruction());

  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
}
// An add whose two operands are the same value is expected to be duplicated
// into the root fusion, even though fusion cannot proceed through the rng.
//
// p0 -> add -------------------------> sub
//           \-> abs1 -> rng -> abs2 -/
TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[] parameter(0)
add = f32[] add(p0, p0)
abs1 = f32[] abs(add)
rng = f32[] rng(p0, abs1), distribution=rng_uniform
abs2 = f32[] abs(rng)
abs3 = f32[] abs(rng)
ROOT root = f32[] subtract(abs2, add)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // abs2 is expected to be fused into root.
  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed) << module->ToString();

  HloInstruction* fused_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(fused_root, op::Fusion());
  EXPECT_THAT(fused_root->fused_expression_root(),
              op::Subtract(op::Abs(op::Parameter()),
                           op::Add(op::Parameter(), op::Parameter())))
      << module->ToString();

  // Two adds in the module means the original add was duplicated.
  const int add_count = Count(*module, HloOpcode::kAdd);
  EXPECT_EQ(add_count, 2) << module->ToString();
}
// Diamond graph: add -> {slice1, slice2} -> add2. With may_duplicate=false the
// whole diamond must still collapse into one fusion fed only by parameters.
TEST_F(InstructionFusionTest, FuseDiamondGraphsNoDuplication) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY Test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add = f32[100] add(p0, p1)
slice1 = f32[99] slice(add), slice={[0:99:1]}
slice2 = f32[99] slice(add), slice={[1:100:1]}
ROOT add2 = f32[99] add(slice1, slice2)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/false);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed) << module->ToString();

  HloInstruction* fused_root = module->entry_computation()->root_instruction();
  // Fusing 'add' up front would require duplicating it, since it has two
  // users. Once 'slice1' and 'slice2' have both been fused into 'add2',
  // however, 'add' is left with a single user and can be fused as well.
  EXPECT_THAT(fused_root, op::Fusion(op::Parameter(), op::Parameter()));
}
// Same diamond graph as the NoDuplication variant, but run with
// may_duplicate=true; the root must again be one fusion fed only by
// parameters.
TEST_F(InstructionFusionTest, FuseDiamondGraphsAllowDuplication) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY Test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add = f32[100] add(p0, p1)
slice1 = f32[99] slice(add), slice={[0:99:1]}
slice2 = f32[99] slice(add), slice={[1:100:1]}
ROOT add2 = f32[99] add(slice1, slice2)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/true);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed) << module->ToString();

  HloInstruction* fused_root = module->entry_computation()->root_instruction();
  // Fusing 'add' up front would require duplicating it, since it has two
  // users. Once 'slice1' and 'slice2' have both been fused into 'add2',
  // however, 'add' is left with a single user and can be fused as well.
  EXPECT_THAT(fused_root, op::Fusion(op::Parameter(), op::Parameter()));
}
// An f16 -> f32 convert with several users must be fused into its consumers
// even when the pass is created with may_duplicate=false.
TEST_F(InstructionFusionTest,
       WideningConvertsAreAlwaysDuplicableIntoConsumers) {
  const char* const kHloText = R"(
HloModule test_module
ENTRY Test {
p0 = f16[100] parameter(0)
c = f32[100] convert(p0)
add = f32[100] add(c, c)
ROOT mul = f32[100] multiply(c, c)
})";
  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // Although may_duplicate is false, the convert should still be fused into
  // the add and mul: fusing/duplicating a widening convert into its consumers
  // is always beneficial.
  InstructionFusion fusion_pass(InstructionFusion::IsExpensive,
                                /*may_duplicate=*/false);
  const bool changed = fusion_pass.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed) << module->ToString();

  HloInstruction* fused_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(fused_root, op::Fusion(op::Parameter()));
}
} // namespace xla