From f1936485847d5d2956aad5a960b5a97fa406c5a9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 6 Mar 2019 16:08:57 -0800 Subject: [PATCH] Add an option to GraphTopologyView to ignore control edges in the graph. PiperOrigin-RevId: 237139933 --- .../core/grappler/graph_topology_view.cc | 26 ++++++++++++-- .../core/grappler/graph_topology_view.h | 4 +++ .../core/grappler/graph_topology_view_test.cc | 36 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/grappler/graph_topology_view.cc b/tensorflow/core/grappler/graph_topology_view.cc index 79e2f9a92fd..86d86c3aa72 100644 --- a/tensorflow/core/grappler/graph_topology_view.cc +++ b/tensorflow/core/grappler/graph_topology_view.cc @@ -40,7 +40,8 @@ inline void SortAndRemoveDuplicates(T* v) { Status GraphTopologyView::InitializeFromGraph( const GraphDef& graph, - const absl::Span ephemeral_edges) { + const absl::Span ephemeral_edges, + bool ignore_control_edges) { if (graph_ != nullptr) { return errors::InvalidArgument("GraphTopologyView is already initialized."); } @@ -63,7 +64,6 @@ Status GraphTopologyView::InitializeFromGraph( for (const GraphView::Edge& edge : ephemeral_edges) { const auto src = node_name_to_index_.find(edge.src.node->name()); const bool valid_src = src != node_name_to_index_.end(); - if (!valid_src) { const string error_message = absl::StrCat("Non-existent src node: ", edge.src.node->name()); @@ -90,6 +90,9 @@ Status GraphTopologyView::InitializeFromGraph( if (valid_dst && valid_src) { const int src_idx = src->second; const int dst_idx = dst->second; + if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) { + continue; + } fanins_[dst_idx].push_back(src_idx); fanouts_[src_idx].push_back(dst_idx); } @@ -102,6 +105,9 @@ Status GraphTopologyView::InitializeFromGraph( for (const string& input : node.input()) { TensorId tensor = ParseTensorName(input); + if (ignore_control_edges && IsTensorIdControl(tensor)) { + continue; + } const auto it = node_name_to_index_.find(tensor.node()); const bool valid_input = it != node_name_to_index_.end(); @@ -134,8 +140,22 @@ Status GraphTopologyView::InitializeFromGraph( return Status::OK(); } +Status GraphTopologyView::InitializeFromGraph( + const GraphDef& graph, + const absl::Span ephemeral_edges) { + return InitializeFromGraph(graph, ephemeral_edges, + /*ignore_control_edges=*/false); +} + +Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph, + bool ignore_control_edges) { + return InitializeFromGraph(graph, absl::Span(), + ignore_control_edges); +} + Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) { - return InitializeFromGraph(graph, absl::Span()); + return InitializeFromGraph(graph, absl::Span(), + /*ignore_control_edges*/ false); } bool GraphTopologyView::HasNode(const absl::string_view node_name) const { diff --git a/tensorflow/core/grappler/graph_topology_view.h b/tensorflow/core/grappler/graph_topology_view.h index c40d0093b90..d32e51928e0 100644 --- a/tensorflow/core/grappler/graph_topology_view.h +++ b/tensorflow/core/grappler/graph_topology_view.h @@ -55,8 +55,12 @@ class GraphTopologyView { // computing graph topology. Example: Tensorflow runtime allows concurrent // execution of dequeue/enqueue ops from the same queue resource, but we might // want to enforce ordering between them for the purpose of graph analysis. + Status InitializeFromGraph(const GraphDef& graph, + absl::Span ephemeral_edges, + bool ignore_control_edges); Status InitializeFromGraph(const GraphDef& graph, absl::Span ephemeral_edges); + Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges); Status InitializeFromGraph(const GraphDef& graph); bool is_initialized() const { return graph_ != nullptr; } diff --git a/tensorflow/core/grappler/graph_topology_view_test.cc b/tensorflow/core/grappler/graph_topology_view_test.cc index 36d3a2017cc..4d93eaa0b19 100644 --- a/tensorflow/core/grappler/graph_topology_view_test.cc +++ b/tensorflow/core/grappler/graph_topology_view_test.cc @@ -113,5 +113,41 @@ TEST_F(GraphTopologyViewTest, GraphWithALoop) { EXPECT_EQ(graph_view.GetFanout(3), Fanout({2})); } +TEST_F(GraphTopologyViewTest, GraphWithControls) { + const GraphDef graph = CreateGraph({ + {"a", {}}, // idx: 0 + {"b", {}}, // idx: 1 + {"c", {"a", "b", "^d"}}, // idx: 2 + {"d", {"a", "c"}}, // idx: 3 + }); + + { + GraphTopologyView graph_view; + TF_CHECK_OK(graph_view.InitializeFromGraph(graph)); + EXPECT_TRUE(graph_view.is_initialized()); + + using Fanin = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3})); + EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2})); + + using Fanout = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanout(2), Fanout({3})); + EXPECT_EQ(graph_view.GetFanout(3), Fanout({2})); + } + { + GraphTopologyView graph_view; + TF_CHECK_OK( + graph_view.InitializeFromGraph(graph, /*ignore_controls*/ true)); + EXPECT_TRUE(graph_view.is_initialized()); + using Fanin = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1})); + EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2})); + + using Fanout = absl::InlinedVector; + EXPECT_EQ(graph_view.GetFanout(2), Fanout({3})); + EXPECT_EQ(graph_view.GetFanout(3), Fanout({})); + } +} + } // namespace grappler } // namespace tensorflow