Addressed review comments

This commit is contained in:
mdfaijul 2021-01-05 21:12:10 -08:00
parent 7b31a4830c
commit a770d78407

View File

@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/pattern_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -169,7 +167,7 @@ TEST_F(PatternMatcherTest, Tree) {
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("e");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -251,7 +249,7 @@ TEST_F(PatternMatcherTest, DAG) {
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("e");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -337,7 +335,7 @@ TEST_F(PatternMatcherTest, DAGExternalDependent) {
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("e");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -358,7 +356,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGelu) {
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("gelu");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -394,7 +392,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) {
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("gelu");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -415,7 +413,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
graph_view.SortTopologically(/*ignore_cycles=*/false, {});
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
auto root_node_view = graph_view.GetNode("gelu");
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
@ -445,12 +443,12 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1));
mutation->AddNode(std::move(fused_node), &status);
TF_ASSERT_OK(status);
mutation->Apply();
TF_EXPECT_OK(mutation->Apply());
// Remove nodes that are marked as NodeStatus::kRemove.
for (auto const& node_idx : remove_node_indices) {
mutation->RemoveNode(graph_view.GetNode(node_idx));
}
mutation->Apply();
TF_EXPECT_OK(mutation->Apply());
// After mutation number of nodes.
int num_nodes_after = graph_view.NumNodes();