Add EvaluateNodes to HoistFactorDiv test.
PiperOrigin-RevId: 195685340
This commit is contained in:
parent
ac630df3cb
commit
b2888c66e6
@ -696,6 +696,9 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
|
|||||||
item.fetch = {"id"};
|
item.fetch = {"id"};
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||||
|
EXPECT_EQ(1, tensors_expected.size());
|
||||||
|
|
||||||
ArithmeticOptimizer optimizer;
|
ArithmeticOptimizer optimizer;
|
||||||
EnableOnlyHoistCommonFactor(&optimizer);
|
EnableOnlyHoistCommonFactor(&optimizer);
|
||||||
|
|
||||||
@ -734,6 +737,13 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
|
|||||||
EXPECT_EQ("id", id_node->name());
|
EXPECT_EQ("id", id_node->name());
|
||||||
EXPECT_EQ(HoistDivName("add"), id_node->input(0));
|
EXPECT_EQ(HoistDivName("add"), id_node->input(0));
|
||||||
}
|
}
|
||||||
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
|
EXPECT_EQ(1, tensors.size());
|
||||||
|
if (use_ints) {
|
||||||
|
test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
|
||||||
|
} else {
|
||||||
|
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user