Add EvaluateNodes to HoistFactorDiv test.
PiperOrigin-RevId: 195685340
This commit is contained in:
parent
ac630df3cb
commit
b2888c66e6
@ -696,6 +696,9 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
|
||||
item.fetch = {"id"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
EXPECT_EQ(1, tensors_expected.size());
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
EnableOnlyHoistCommonFactor(&optimizer);
|
||||
|
||||
@ -734,6 +737,13 @@ TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
|
||||
EXPECT_EQ("id", id_node->name());
|
||||
EXPECT_EQ(HoistDivName("add"), id_node->input(0));
|
||||
}
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
EXPECT_EQ(1, tensors.size());
|
||||
if (use_ints) {
|
||||
test::ExpectTensorEqual<int32>(tensors_expected[0], tensors[0]);
|
||||
} else {
|
||||
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user