Add an option to GraphTopologyView to ignore control edges in the graph.

PiperOrigin-RevId: 237139933
This commit is contained in:
A. Unique TensorFlower 2019-03-06 16:08:57 -08:00 committed by TensorFlower Gardener
parent 31395d336e
commit f193648584
3 changed files with 63 additions and 3 deletions

View File

@ -40,7 +40,8 @@ inline void SortAndRemoveDuplicates(T* v) {
Status GraphTopologyView::InitializeFromGraph(
const GraphDef& graph,
const absl::Span<const GraphView::Edge> ephemeral_edges) {
const absl::Span<const GraphView::Edge> ephemeral_edges,
bool ignore_control_edges) {
if (graph_ != nullptr) {
return errors::InvalidArgument("GraphTopologyView is already initialized.");
}
@ -63,7 +64,6 @@ Status GraphTopologyView::InitializeFromGraph(
for (const GraphView::Edge& edge : ephemeral_edges) {
const auto src = node_name_to_index_.find(edge.src.node->name());
const bool valid_src = src != node_name_to_index_.end();
if (!valid_src) {
const string error_message =
absl::StrCat("Non-existent src node: ", edge.src.node->name());
@ -90,6 +90,9 @@ Status GraphTopologyView::InitializeFromGraph(
if (valid_dst && valid_src) {
const int src_idx = src->second;
const int dst_idx = dst->second;
if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) {
continue;
}
fanins_[dst_idx].push_back(src_idx);
fanouts_[src_idx].push_back(dst_idx);
}
@ -102,6 +105,9 @@ Status GraphTopologyView::InitializeFromGraph(
for (const string& input : node.input()) {
TensorId tensor = ParseTensorName(input);
if (ignore_control_edges && IsTensorIdControl(tensor)) {
continue;
}
const auto it = node_name_to_index_.find(tensor.node());
const bool valid_input = it != node_name_to_index_.end();
@ -134,8 +140,22 @@ Status GraphTopologyView::InitializeFromGraph(
return Status::OK();
}
Status GraphTopologyView::InitializeFromGraph(
const GraphDef& graph,
const absl::Span<const GraphView::Edge> ephemeral_edges) {
return InitializeFromGraph(graph, ephemeral_edges,
/*ignore_control_edges=*/false);
}
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph,
bool ignore_control_edges) {
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
ignore_control_edges);
}
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>());
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
/*ignore_control_edges*/ false);
}
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {

View File

@ -55,8 +55,12 @@ class GraphTopologyView {
// computing graph topology. Example: Tensorflow runtime allows concurrent
// execution of dequeue/enqueue ops from the same queue resource, but we might
// want to enforce ordering between them for the purpose of graph analysis.
Status InitializeFromGraph(const GraphDef& graph,
absl::Span<const GraphView::Edge> ephemeral_edges,
bool ignore_control_edges);
Status InitializeFromGraph(const GraphDef& graph,
absl::Span<const GraphView::Edge> ephemeral_edges);
Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges);
Status InitializeFromGraph(const GraphDef& graph);
bool is_initialized() const { return graph_ != nullptr; }

View File

@ -113,5 +113,41 @@ TEST_F(GraphTopologyViewTest, GraphWithALoop) {
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
}
TEST_F(GraphTopologyViewTest, GraphWithControls) {
const GraphDef graph = CreateGraph({
{"a", {}}, // idx: 0
{"b", {}}, // idx: 1
{"c", {"a", "b", "^d"}}, // idx: 2
{"d", {"a", "c"}}, // idx: 3
});
{
GraphTopologyView graph_view;
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
EXPECT_TRUE(graph_view.is_initialized());
using Fanin = absl::InlinedVector<int, 4>;
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3}));
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
using Fanout = absl::InlinedVector<int, 2>;
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
}
{
GraphTopologyView graph_view;
TF_CHECK_OK(
graph_view.InitializeFromGraph(graph, /*ignore_controls*/ true));
EXPECT_TRUE(graph_view.is_initialized());
using Fanin = absl::InlinedVector<int, 4>;
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1}));
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
using Fanout = absl::InlinedVector<int, 2>;
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
EXPECT_EQ(graph_view.GetFanout(3), Fanout({}));
}
}
} // namespace grappler
} // namespace tensorflow