diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 88b1e975c13..63c58a0aede 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -318,11 +318,18 @@ class GraphViewInternal { protected: explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} + Status AddUniqueNode(NodeDefT* node) { + auto inserted = nodes_.emplace(node->name(), node); + return inserted.second + ? Status::OK() + : errors::InvalidArgument("Non unique node name detected: ", + node->name()); + } + + // TODO(ezhulenev): Remove this function. void AddUniqueNodeOrDie(NodeDefT* node) { - auto result = nodes_.emplace(node->name(), node); - // TODO(ezhulenev): Replace CHECK with factory method returning - // absl::StatusOr (when available). - CHECK(result.second) << "Non unique node name detected: " << node->name(); + Status st = AddUniqueNode(node); + CHECK(st.ok()) << st.error_message(); } // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 43dec07c1db..0f171f76563 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -215,6 +215,31 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) { return node_in_graph; } +Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { + if (subgraph.library().function_size() != 0) { + return errors::InvalidArgument( + "Can't add a subgraph with non-empty function library"); + } + + int node_size_before = graph()->node_size(); + + for (NodeDef& node : *subgraph.mutable_node()) { + auto* node_in_graph = graph()->add_node(); + *node_in_graph = std::move(node); + TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph)); + } + + // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that + // fanins actually exists in the graph, and there is already TODO for that. + + for (int i = node_size_before; i < graph()->node_size(); ++i) { + NodeDef* node = graph()->mutable_node(i); + AddAndDedupFanouts(node); + } + + return Status::OK(); +} + Status MutableGraphView::UpdateFanouts(absl::string_view from_node, absl::string_view to_node) { NodeDef* from_node_ptr = GetNode(from_node); diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h index c62129bcadd..f08b5dde088 100644 --- a/tensorflow/core/grappler/mutable_graph_view.h +++ b/tensorflow/core/grappler/mutable_graph_view.h @@ -63,6 +63,16 @@ class MutableGraphView : public internal::GraphViewInternal { // node in graph. NodeDef* AddNode(NodeDef&& node); + // Adds all nodes from the `subgraph` to the underlying graph and updates the + // view. `subgraph` doesn't have to be a valid graph definition on it's own, + // it can have edges to the nodes that are not in it, however after adding + // it to the underlying graph, final graph must be valid. + // + // TODO(ezhulenev): Currently it will fail if subgraph has non-empty function + // library. Add support for adding new functions from the subgraph function + // library into the underlying graph. + Status AddSubgraph(GraphDef&& subgraph); + // Updates all fanouts (input ports fetching output tensors) from `from_node` // to the `to_node`, including control dependencies. // diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index acfaba5ddd3..ea88359e2ea 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -141,6 +141,58 @@ void CheckGraph(const MutableGraphView& mutable_graph) { } } +TEST(MutableGraphViewTest, AddSubgraph) { + GraphDef graph_def = test::function::GDef( + { + NDef("foo", "NotImportant", {}, {}), + NDef("bar", "NotImportant", {}, {}), + NDef("baz", "NotImportant", {"foo", "bar"}), + }, + /*funcs=*/{}); + MutableGraphView graph(&graph_def); + + // `s/bar` node has inputs that are valid only if we add subgraph into the + // original graph. + GraphDef subgraph = test::function::GDef( + { + NDef("s/n0", "NotImportant", {}, {}), + NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}), + }, + /*funcs=*/{}); + + TF_EXPECT_OK(graph.AddSubgraph(std::move(subgraph))); + + // Fanins and fanouts must be updated for the nodes of the original graph, and + // added subgraph. + CheckNode(graph, "bar", "NotImportant", "", {}, {}, {"baz:1", "s/n1"}); + CheckNode(graph, "s/n1", "NotImportant", "", {}, {"bar", "s/n0"}, {}); + CheckGraph(graph); +} + +// TODO(ezhulenev): Add support for adding a subgraph and merging function +// libraries. +TEST(MutableGraphViewTest, AddSubgraphWithFunctionLibrary) { + GraphDef graph_def = test::function::GDef( + { + NDef("foo", "NotImportant", {}, {}), + NDef("bar", "NotImportant", {}, {}), + NDef("baz", "NotImportant", {"foo", "bar"}), + }, + /*funcs=*/{}); + MutableGraphView graph(&graph_def); + + FunctionDef x_times_two = test::function::XTimesTwo(); + GraphDef subgraph = test::function::GDef( + { + NDef("s/n0", "NotImportant", {}, {}), + NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}), + }, + /*funcs=*/{x_times_two}); + + Status status = graph.AddSubgraph(std::move(subgraph)); + EXPECT_FALSE(status.ok()); +} + TEST(MutableGraphViewTest, AddAndUpdateFanouts) { // Actual node.op() is not important in this test. GraphDef graph_def = test::function::GDef(