Pre-factoring: Fix overly specific test expectations to prepare for multi-output fusion.
PiperOrigin-RevId: 196514026
This commit is contained in:
parent
4b1fa0ccdc
commit
0c59fdb949
@ -92,7 +92,8 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) {
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
.ValueOrDie())
|
||||
<< module->ToString();
|
||||
}
|
||||
|
||||
// Counts the number of HLO ops with a given op code in the specified module.
|
||||
@ -151,7 +152,11 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
|
||||
.Run(module.get())
|
||||
.ValueOrDie())
|
||||
<< module->ToString();
|
||||
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
|
||||
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Fusion());
|
||||
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Subtract(op::Abs(op::Parameter()), op::Parameter()))
|
||||
<< module->ToString();
|
||||
|
||||
// Make sure the add hasn't been duplicated.
|
||||
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString();
|
||||
@ -244,7 +249,12 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
|
||||
.Run(module.get())
|
||||
.ValueOrDie())
|
||||
<< module->ToString();
|
||||
EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1) << module->ToString();
|
||||
root = module->entry_computation()->root_instruction();
|
||||
EXPECT_THAT(root, op::Fusion());
|
||||
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Tuple(op::Subtract(op::Parameter(), op::Parameter()),
|
||||
op::Subtract(op::Parameter(), op::Parameter())))
|
||||
<< module->ToString();
|
||||
|
||||
// Make sure we didn't duplicate any adds.
|
||||
EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString();
|
||||
|
Loading…
Reference in New Issue
Block a user