From a770d78407993eb876cbc08263294cefe15ad19c Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Tue, 5 Jan 2021 21:12:10 -0800 Subject: [PATCH] Addressed review comments --- .../core/grappler/utils/pattern_utils_test.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/grappler/utils/pattern_utils_test.cc b/tensorflow/core/grappler/utils/pattern_utils_test.cc index f2ea0b6af92..67dc20b6a5d 100644 --- a/tensorflow/core/grappler/utils/pattern_utils_test.cc +++ b/tensorflow/core/grappler/utils/pattern_utils_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/pattern_utils.h" -#include "tensorflow/core/util/dump_graph.h" - #include "tensorflow/cc/ops/nn_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -169,7 +167,7 @@ TEST_F(PatternMatcherTest, Tree) { MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("e"); SubGraphMatcher graph_matcher(&graph_view); @@ -251,7 +249,7 @@ TEST_F(PatternMatcherTest, DAG) { MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("e"); SubGraphMatcher graph_matcher(&graph_view); @@ -337,7 +335,7 @@ TEST_F(PatternMatcherTest, DAGExternalDependent) { MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("e"); SubGraphMatcher graph_matcher(&graph_view); @@ -358,7 +356,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGelu) { OpTypePattern pattern = GetMatMulBiasAddGeluPattern(); MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("gelu"); SubGraphMatcher graph_matcher(&graph_view); @@ -394,7 +392,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) { OpTypePattern pattern = GetMatMulBiasAddGeluPattern(); MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("gelu"); SubGraphMatcher graph_matcher(&graph_view); @@ -415,7 +413,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) { OpTypePattern pattern = GetMatMulBiasAddGeluPattern(); MutableGraphView graph_view(&graph, &status); TF_ASSERT_OK(status); - graph_view.SortTopologically(/*ignore_cycles=*/false, {}); + TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {})); auto root_node_view = graph_view.GetNode("gelu"); SubGraphMatcher graph_matcher(&graph_view); @@ -445,12 +443,12 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) { fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1)); mutation->AddNode(std::move(fused_node), &status); TF_ASSERT_OK(status); - mutation->Apply(); + TF_EXPECT_OK(mutation->Apply()); // Remove nodes that are marked as NodeStatus::kRemove. for (auto const& node_idx : remove_node_indices) { mutation->RemoveNode(graph_view.GetNode(node_idx)); } - mutation->Apply(); + TF_EXPECT_OK(mutation->Apply()); // After mutation number of nodes. int num_nodes_after = graph_view.NumNodes();