Add an option to GraphTopologyView to ignore control edges in the graph.
PiperOrigin-RevId: 237139933
This commit is contained in:
parent
31395d336e
commit
f193648584
@ -40,7 +40,8 @@ inline void SortAndRemoveDuplicates(T* v) {
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(
|
||||
const GraphDef& graph,
|
||||
const absl::Span<const GraphView::Edge> ephemeral_edges) {
|
||||
const absl::Span<const GraphView::Edge> ephemeral_edges,
|
||||
bool ignore_control_edges) {
|
||||
if (graph_ != nullptr) {
|
||||
return errors::InvalidArgument("GraphTopologyView is already initialized.");
|
||||
}
|
||||
@ -63,7 +64,6 @@ Status GraphTopologyView::InitializeFromGraph(
|
||||
for (const GraphView::Edge& edge : ephemeral_edges) {
|
||||
const auto src = node_name_to_index_.find(edge.src.node->name());
|
||||
const bool valid_src = src != node_name_to_index_.end();
|
||||
|
||||
if (!valid_src) {
|
||||
const string error_message =
|
||||
absl::StrCat("Non-existent src node: ", edge.src.node->name());
|
||||
@ -90,6 +90,9 @@ Status GraphTopologyView::InitializeFromGraph(
|
||||
if (valid_dst && valid_src) {
|
||||
const int src_idx = src->second;
|
||||
const int dst_idx = dst->second;
|
||||
if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) {
|
||||
continue;
|
||||
}
|
||||
fanins_[dst_idx].push_back(src_idx);
|
||||
fanouts_[src_idx].push_back(dst_idx);
|
||||
}
|
||||
@ -102,6 +105,9 @@ Status GraphTopologyView::InitializeFromGraph(
|
||||
|
||||
for (const string& input : node.input()) {
|
||||
TensorId tensor = ParseTensorName(input);
|
||||
if (ignore_control_edges && IsTensorIdControl(tensor)) {
|
||||
continue;
|
||||
}
|
||||
const auto it = node_name_to_index_.find(tensor.node());
|
||||
const bool valid_input = it != node_name_to_index_.end();
|
||||
|
||||
@ -134,8 +140,22 @@ Status GraphTopologyView::InitializeFromGraph(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(
|
||||
const GraphDef& graph,
|
||||
const absl::Span<const GraphView::Edge> ephemeral_edges) {
|
||||
return InitializeFromGraph(graph, ephemeral_edges,
|
||||
/*ignore_control_edges=*/false);
|
||||
}
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph,
|
||||
bool ignore_control_edges) {
|
||||
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
|
||||
ignore_control_edges);
|
||||
}
|
||||
|
||||
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
|
||||
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>());
|
||||
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
|
||||
/*ignore_control_edges*/ false);
|
||||
}
|
||||
|
||||
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
|
||||
|
@ -55,8 +55,12 @@ class GraphTopologyView {
|
||||
// computing graph topology. Example: Tensorflow runtime allows concurrent
|
||||
// execution of dequeue/enqueue ops from the same queue resource, but we might
|
||||
// want to enforce ordering between them for the purpose of graph analysis.
|
||||
Status InitializeFromGraph(const GraphDef& graph,
|
||||
absl::Span<const GraphView::Edge> ephemeral_edges,
|
||||
bool ignore_control_edges);
|
||||
Status InitializeFromGraph(const GraphDef& graph,
|
||||
absl::Span<const GraphView::Edge> ephemeral_edges);
|
||||
Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges);
|
||||
Status InitializeFromGraph(const GraphDef& graph);
|
||||
|
||||
bool is_initialized() const { return graph_ != nullptr; }
|
||||
|
@ -113,5 +113,41 @@ TEST_F(GraphTopologyViewTest, GraphWithALoop) {
|
||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
||||
}
|
||||
|
||||
TEST_F(GraphTopologyViewTest, GraphWithControls) {
|
||||
const GraphDef graph = CreateGraph({
|
||||
{"a", {}}, // idx: 0
|
||||
{"b", {}}, // idx: 1
|
||||
{"c", {"a", "b", "^d"}}, // idx: 2
|
||||
{"d", {"a", "c"}}, // idx: 3
|
||||
});
|
||||
|
||||
{
|
||||
GraphTopologyView graph_view;
|
||||
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
|
||||
EXPECT_TRUE(graph_view.is_initialized());
|
||||
|
||||
using Fanin = absl::InlinedVector<int, 4>;
|
||||
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3}));
|
||||
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||
|
||||
using Fanout = absl::InlinedVector<int, 2>;
|
||||
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
||||
}
|
||||
{
|
||||
GraphTopologyView graph_view;
|
||||
TF_CHECK_OK(
|
||||
graph_view.InitializeFromGraph(graph, /*ignore_controls*/ true));
|
||||
EXPECT_TRUE(graph_view.is_initialized());
|
||||
using Fanin = absl::InlinedVector<int, 4>;
|
||||
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1}));
|
||||
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||
|
||||
using Fanout = absl::InlinedVector<int, 2>;
|
||||
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout({}));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user