[Grappler] AddSubgraph in MutableGraphView
PiperOrigin-RevId: 230026030
This commit is contained in:
parent
1790a016fc
commit
e2968144a0
@ -318,11 +318,18 @@ class GraphViewInternal {
|
|||||||
protected:
|
protected:
|
||||||
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
|
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
|
||||||
|
|
||||||
|
Status AddUniqueNode(NodeDefT* node) {
|
||||||
|
auto inserted = nodes_.emplace(node->name(), node);
|
||||||
|
return inserted.second
|
||||||
|
? Status::OK()
|
||||||
|
: errors::InvalidArgument("Non unique node name detected: ",
|
||||||
|
node->name());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ezhulenev): Remove this function.
|
||||||
void AddUniqueNodeOrDie(NodeDefT* node) {
|
void AddUniqueNodeOrDie(NodeDefT* node) {
|
||||||
auto result = nodes_.emplace(node->name(), node);
|
Status st = AddUniqueNode(node);
|
||||||
// TODO(ezhulenev): Replace CHECK with factory method returning
|
CHECK(st.ok()) << st.error_message();
|
||||||
// absl::StatusOr (when available).
|
|
||||||
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
|
// TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
|
||||||
|
@ -215,6 +215,31 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
|||||||
return node_in_graph;
|
return node_in_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
|
||||||
|
if (subgraph.library().function_size() != 0) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Can't add a subgraph with non-empty function library");
|
||||||
|
}
|
||||||
|
|
||||||
|
int node_size_before = graph()->node_size();
|
||||||
|
|
||||||
|
for (NodeDef& node : *subgraph.mutable_node()) {
|
||||||
|
auto* node_in_graph = graph()->add_node();
|
||||||
|
*node_in_graph = std::move(node);
|
||||||
|
TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
|
||||||
|
// fanins actually exists in the graph, and there is already TODO for that.
|
||||||
|
|
||||||
|
for (int i = node_size_before; i < graph()->node_size(); ++i) {
|
||||||
|
NodeDef* node = graph()->mutable_node(i);
|
||||||
|
AddAndDedupFanouts(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status MutableGraphView::UpdateFanouts(absl::string_view from_node,
|
Status MutableGraphView::UpdateFanouts(absl::string_view from_node,
|
||||||
absl::string_view to_node) {
|
absl::string_view to_node) {
|
||||||
NodeDef* from_node_ptr = GetNode(from_node);
|
NodeDef* from_node_ptr = GetNode(from_node);
|
||||||
|
@ -63,6 +63,16 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
|||||||
// node in graph.
|
// node in graph.
|
||||||
NodeDef* AddNode(NodeDef&& node);
|
NodeDef* AddNode(NodeDef&& node);
|
||||||
|
|
||||||
|
// Adds all nodes from the `subgraph` to the underlying graph and updates the
|
||||||
|
// view. `subgraph` doesn't have to be a valid graph definition on it's own,
|
||||||
|
// it can have edges to the nodes that are not in it, however after adding
|
||||||
|
// it to the underlying graph, final graph must be valid.
|
||||||
|
//
|
||||||
|
// TODO(ezhulenev): Currently it will fail if subgraph has non-empty function
|
||||||
|
// library. Add support for adding new functions from the subgraph function
|
||||||
|
// library into the underlying graph.
|
||||||
|
Status AddSubgraph(GraphDef&& subgraph);
|
||||||
|
|
||||||
// Updates all fanouts (input ports fetching output tensors) from `from_node`
|
// Updates all fanouts (input ports fetching output tensors) from `from_node`
|
||||||
// to the `to_node`, including control dependencies.
|
// to the `to_node`, including control dependencies.
|
||||||
//
|
//
|
||||||
|
@ -141,6 +141,58 @@ void CheckGraph(const MutableGraphView& mutable_graph) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(MutableGraphViewTest, AddSubgraph) {
|
||||||
|
GraphDef graph_def = test::function::GDef(
|
||||||
|
{
|
||||||
|
NDef("foo", "NotImportant", {}, {}),
|
||||||
|
NDef("bar", "NotImportant", {}, {}),
|
||||||
|
NDef("baz", "NotImportant", {"foo", "bar"}),
|
||||||
|
},
|
||||||
|
/*funcs=*/{});
|
||||||
|
MutableGraphView graph(&graph_def);
|
||||||
|
|
||||||
|
// `s/bar` node has inputs that are valid only if we add subgraph into the
|
||||||
|
// original graph.
|
||||||
|
GraphDef subgraph = test::function::GDef(
|
||||||
|
{
|
||||||
|
NDef("s/n0", "NotImportant", {}, {}),
|
||||||
|
NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
|
||||||
|
},
|
||||||
|
/*funcs=*/{});
|
||||||
|
|
||||||
|
TF_EXPECT_OK(graph.AddSubgraph(std::move(subgraph)));
|
||||||
|
|
||||||
|
// Fanins and fanouts must be updated for the nodes of the original graph, and
|
||||||
|
// added subgraph.
|
||||||
|
CheckNode(graph, "bar", "NotImportant", "", {}, {}, {"baz:1", "s/n1"});
|
||||||
|
CheckNode(graph, "s/n1", "NotImportant", "", {}, {"bar", "s/n0"}, {});
|
||||||
|
CheckGraph(graph);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(ezhulenev): Add support for adding a subgraph and merging function
|
||||||
|
// libraries.
|
||||||
|
TEST(MutableGraphViewTest, AddSubgraphWithFunctionLibrary) {
|
||||||
|
GraphDef graph_def = test::function::GDef(
|
||||||
|
{
|
||||||
|
NDef("foo", "NotImportant", {}, {}),
|
||||||
|
NDef("bar", "NotImportant", {}, {}),
|
||||||
|
NDef("baz", "NotImportant", {"foo", "bar"}),
|
||||||
|
},
|
||||||
|
/*funcs=*/{});
|
||||||
|
MutableGraphView graph(&graph_def);
|
||||||
|
|
||||||
|
FunctionDef x_times_two = test::function::XTimesTwo();
|
||||||
|
GraphDef subgraph = test::function::GDef(
|
||||||
|
{
|
||||||
|
NDef("s/n0", "NotImportant", {}, {}),
|
||||||
|
NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
|
||||||
|
},
|
||||||
|
/*funcs=*/{x_times_two});
|
||||||
|
|
||||||
|
Status status = graph.AddSubgraph(std::move(subgraph));
|
||||||
|
EXPECT_FALSE(status.ok());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
|
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
|
||||||
// Actual node.op() is not important in this test.
|
// Actual node.op() is not important in this test.
|
||||||
GraphDef graph_def = test::function::GDef(
|
GraphDef graph_def = test::function::GDef(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user