Pre-factoring: Fix overly specific test expectations to prepare for multi-output fusion.

PiperOrigin-RevId: 196514026
This commit is contained in:
A. Unique TensorFlower 2018-05-14 09:06:25 -07:00 committed by TensorFlower Gardener
parent 4b1fa0ccdc
commit 0c59fdb949

View File

@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
EXPECT_FALSE(
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
.ValueOrDie())
<< module->ToString();
}
// Counts the number of HLO ops with a given op code in the specified module.
@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_THAT(root->fused_expression_root(),
op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
<< module->ToString();
// Make sure the add hasn't been duplicated.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
.Run(module.get())
.ValueOrDie())
<< module->ToString();
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_THAT(root->fused_expression_root(),
op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
op::Subtract(op::Parameter(), op::Parameter())))
<< module->ToString();
// Make sure we didn't duplicate any adds.
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();