Addressed review comments
This commit is contained in:
parent
7b31a4830c
commit
a770d78407
@ -14,8 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/pattern_utils.h"
|
||||
#include "tensorflow/core/util/dump_graph.h"
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -169,7 +167,7 @@ TEST_F(PatternMatcherTest, Tree) {
|
||||
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("e");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -251,7 +249,7 @@ TEST_F(PatternMatcherTest, DAG) {
|
||||
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("e");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -337,7 +335,7 @@ TEST_F(PatternMatcherTest, DAGExternalDependent) {
|
||||
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("e");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -358,7 +356,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGelu) {
|
||||
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("gelu");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -394,7 +392,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) {
|
||||
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("gelu");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -415,7 +413,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
|
||||
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("gelu");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
@ -445,12 +443,12 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
|
||||
fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1));
|
||||
mutation->AddNode(std::move(fused_node), &status);
|
||||
TF_ASSERT_OK(status);
|
||||
mutation->Apply();
|
||||
TF_EXPECT_OK(mutation->Apply());
|
||||
// Remove nodes that are marked as NodeStatus::kRemove.
|
||||
for (auto const& node_idx : remove_node_indices) {
|
||||
mutation->RemoveNode(graph_view.GetNode(node_idx));
|
||||
}
|
||||
mutation->Apply();
|
||||
TF_EXPECT_OK(mutation->Apply());
|
||||
|
||||
// After mutation number of nodes.
|
||||
int num_nodes_after = graph_view.NumNodes();
|
||||
|
Loading…
Reference in New Issue
Block a user