Add an option to GraphTopologyView to ignore control edges in the graph.
PiperOrigin-RevId: 237139933
This commit is contained in:
parent
31395d336e
commit
f193648584
@ -40,7 +40,8 @@ inline void SortAndRemoveDuplicates(T* v) {
|
|||||||
|
|
||||||
Status GraphTopologyView::InitializeFromGraph(
|
Status GraphTopologyView::InitializeFromGraph(
|
||||||
const GraphDef& graph,
|
const GraphDef& graph,
|
||||||
const absl::Span<const GraphView::Edge> ephemeral_edges) {
|
const absl::Span<const GraphView::Edge> ephemeral_edges,
|
||||||
|
bool ignore_control_edges) {
|
||||||
if (graph_ != nullptr) {
|
if (graph_ != nullptr) {
|
||||||
return errors::InvalidArgument("GraphTopologyView is already initialized.");
|
return errors::InvalidArgument("GraphTopologyView is already initialized.");
|
||||||
}
|
}
|
||||||
@ -63,7 +64,6 @@ Status GraphTopologyView::InitializeFromGraph(
|
|||||||
for (const GraphView::Edge& edge : ephemeral_edges) {
|
for (const GraphView::Edge& edge : ephemeral_edges) {
|
||||||
const auto src = node_name_to_index_.find(edge.src.node->name());
|
const auto src = node_name_to_index_.find(edge.src.node->name());
|
||||||
const bool valid_src = src != node_name_to_index_.end();
|
const bool valid_src = src != node_name_to_index_.end();
|
||||||
|
|
||||||
if (!valid_src) {
|
if (!valid_src) {
|
||||||
const string error_message =
|
const string error_message =
|
||||||
absl::StrCat("Non-existent src node: ", edge.src.node->name());
|
absl::StrCat("Non-existent src node: ", edge.src.node->name());
|
||||||
@ -90,6 +90,9 @@ Status GraphTopologyView::InitializeFromGraph(
|
|||||||
if (valid_dst && valid_src) {
|
if (valid_dst && valid_src) {
|
||||||
const int src_idx = src->second;
|
const int src_idx = src->second;
|
||||||
const int dst_idx = dst->second;
|
const int dst_idx = dst->second;
|
||||||
|
if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
fanins_[dst_idx].push_back(src_idx);
|
fanins_[dst_idx].push_back(src_idx);
|
||||||
fanouts_[src_idx].push_back(dst_idx);
|
fanouts_[src_idx].push_back(dst_idx);
|
||||||
}
|
}
|
||||||
@ -102,6 +105,9 @@ Status GraphTopologyView::InitializeFromGraph(
|
|||||||
|
|
||||||
for (const string& input : node.input()) {
|
for (const string& input : node.input()) {
|
||||||
TensorId tensor = ParseTensorName(input);
|
TensorId tensor = ParseTensorName(input);
|
||||||
|
if (ignore_control_edges && IsTensorIdControl(tensor)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
const auto it = node_name_to_index_.find(tensor.node());
|
const auto it = node_name_to_index_.find(tensor.node());
|
||||||
const bool valid_input = it != node_name_to_index_.end();
|
const bool valid_input = it != node_name_to_index_.end();
|
||||||
|
|
||||||
@ -134,8 +140,22 @@ Status GraphTopologyView::InitializeFromGraph(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status GraphTopologyView::InitializeFromGraph(
|
||||||
|
const GraphDef& graph,
|
||||||
|
const absl::Span<const GraphView::Edge> ephemeral_edges) {
|
||||||
|
return InitializeFromGraph(graph, ephemeral_edges,
|
||||||
|
/*ignore_control_edges=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph,
|
||||||
|
bool ignore_control_edges) {
|
||||||
|
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
|
||||||
|
ignore_control_edges);
|
||||||
|
}
|
||||||
|
|
||||||
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
|
Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
|
||||||
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>());
|
return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
|
||||||
|
/*ignore_control_edges*/ false);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
|
bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
|
||||||
|
@ -55,8 +55,12 @@ class GraphTopologyView {
|
|||||||
// computing graph topology. Example: Tensorflow runtime allows concurrent
|
// computing graph topology. Example: Tensorflow runtime allows concurrent
|
||||||
// execution of dequeue/enqueue ops from the same queue resource, but we might
|
// execution of dequeue/enqueue ops from the same queue resource, but we might
|
||||||
// want to enforce ordering between them for the purpose of graph analysis.
|
// want to enforce ordering between them for the purpose of graph analysis.
|
||||||
|
Status InitializeFromGraph(const GraphDef& graph,
|
||||||
|
absl::Span<const GraphView::Edge> ephemeral_edges,
|
||||||
|
bool ignore_control_edges);
|
||||||
Status InitializeFromGraph(const GraphDef& graph,
|
Status InitializeFromGraph(const GraphDef& graph,
|
||||||
absl::Span<const GraphView::Edge> ephemeral_edges);
|
absl::Span<const GraphView::Edge> ephemeral_edges);
|
||||||
|
Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges);
|
||||||
Status InitializeFromGraph(const GraphDef& graph);
|
Status InitializeFromGraph(const GraphDef& graph);
|
||||||
|
|
||||||
bool is_initialized() const { return graph_ != nullptr; }
|
bool is_initialized() const { return graph_ != nullptr; }
|
||||||
|
@ -113,5 +113,41 @@ TEST_F(GraphTopologyViewTest, GraphWithALoop) {
|
|||||||
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphTopologyViewTest, GraphWithControls) {
|
||||||
|
const GraphDef graph = CreateGraph({
|
||||||
|
{"a", {}}, // idx: 0
|
||||||
|
{"b", {}}, // idx: 1
|
||||||
|
{"c", {"a", "b", "^d"}}, // idx: 2
|
||||||
|
{"d", {"a", "c"}}, // idx: 3
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
GraphTopologyView graph_view;
|
||||||
|
TF_CHECK_OK(graph_view.InitializeFromGraph(graph));
|
||||||
|
EXPECT_TRUE(graph_view.is_initialized());
|
||||||
|
|
||||||
|
using Fanin = absl::InlinedVector<int, 4>;
|
||||||
|
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1, 3}));
|
||||||
|
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||||
|
|
||||||
|
using Fanout = absl::InlinedVector<int, 2>;
|
||||||
|
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||||
|
EXPECT_EQ(graph_view.GetFanout(3), Fanout({2}));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
GraphTopologyView graph_view;
|
||||||
|
TF_CHECK_OK(
|
||||||
|
graph_view.InitializeFromGraph(graph, /*ignore_controls*/ true));
|
||||||
|
EXPECT_TRUE(graph_view.is_initialized());
|
||||||
|
using Fanin = absl::InlinedVector<int, 4>;
|
||||||
|
EXPECT_EQ(graph_view.GetFanin(2), Fanin({0, 1}));
|
||||||
|
EXPECT_EQ(graph_view.GetFanin(3), Fanin({0, 2}));
|
||||||
|
|
||||||
|
using Fanout = absl::InlinedVector<int, 2>;
|
||||||
|
EXPECT_EQ(graph_view.GetFanout(2), Fanout({3}));
|
||||||
|
EXPECT_EQ(graph_view.GetFanout(3), Fanout({}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user