[Grappler] AddSubgraph in MutableGraphView
PiperOrigin-RevId: 230026030
This commit is contained in:
parent
1790a016fc
commit
e2968144a0
@ -318,11 +318,18 @@ class GraphViewInternal {
|
||||
protected:
|
||||
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
|
||||
|
||||
Status AddUniqueNode(NodeDefT* node) {
|
||||
auto inserted = nodes_.emplace(node->name(), node);
|
||||
return inserted.second
|
||||
? Status::OK()
|
||||
: errors::InvalidArgument("Non unique node name detected: ",
|
||||
node->name());
|
||||
}
|
||||
|
||||
// TODO(ezhulenev): Remove this function.
|
||||
void AddUniqueNodeOrDie(NodeDefT* node) {
|
||||
auto result = nodes_.emplace(node->name(), node);
|
||||
// TODO(ezhulenev): Replace CHECK with factory method returning
|
||||
// absl::StatusOr (when available).
|
||||
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
||||
Status st = AddUniqueNode(node);
|
||||
CHECK(st.ok()) << st.error_message();
|
||||
}
|
||||
|
||||
// TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
|
||||
|
@ -215,6 +215,31 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
||||
return node_in_graph;
|
||||
}
|
||||
|
||||
Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
|
||||
if (subgraph.library().function_size() != 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Can't add a subgraph with non-empty function library");
|
||||
}
|
||||
|
||||
int node_size_before = graph()->node_size();
|
||||
|
||||
for (NodeDef& node : *subgraph.mutable_node()) {
|
||||
auto* node_in_graph = graph()->add_node();
|
||||
*node_in_graph = std::move(node);
|
||||
TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
|
||||
}
|
||||
|
||||
// TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
|
||||
// fanins actually exists in the graph, and there is already TODO for that.
|
||||
|
||||
for (int i = node_size_before; i < graph()->node_size(); ++i) {
|
||||
NodeDef* node = graph()->mutable_node(i);
|
||||
AddAndDedupFanouts(node);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MutableGraphView::UpdateFanouts(absl::string_view from_node,
|
||||
absl::string_view to_node) {
|
||||
NodeDef* from_node_ptr = GetNode(from_node);
|
||||
|
@ -63,6 +63,16 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
// node in graph.
|
||||
NodeDef* AddNode(NodeDef&& node);
|
||||
|
||||
// Adds all nodes from the `subgraph` to the underlying graph and updates the
|
||||
// view. `subgraph` doesn't have to be a valid graph definition on it's own,
|
||||
// it can have edges to the nodes that are not in it, however after adding
|
||||
// it to the underlying graph, final graph must be valid.
|
||||
//
|
||||
// TODO(ezhulenev): Currently it will fail if subgraph has non-empty function
|
||||
// library. Add support for adding new functions from the subgraph function
|
||||
// library into the underlying graph.
|
||||
Status AddSubgraph(GraphDef&& subgraph);
|
||||
|
||||
// Updates all fanouts (input ports fetching output tensors) from `from_node`
|
||||
// to the `to_node`, including control dependencies.
|
||||
//
|
||||
|
@ -141,6 +141,58 @@ void CheckGraph(const MutableGraphView& mutable_graph) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddSubgraph) {
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{
|
||||
NDef("foo", "NotImportant", {}, {}),
|
||||
NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("baz", "NotImportant", {"foo", "bar"}),
|
||||
},
|
||||
/*funcs=*/{});
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
// `s/bar` node has inputs that are valid only if we add subgraph into the
|
||||
// original graph.
|
||||
GraphDef subgraph = test::function::GDef(
|
||||
{
|
||||
NDef("s/n0", "NotImportant", {}, {}),
|
||||
NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
|
||||
},
|
||||
/*funcs=*/{});
|
||||
|
||||
TF_EXPECT_OK(graph.AddSubgraph(std::move(subgraph)));
|
||||
|
||||
// Fanins and fanouts must be updated for the nodes of the original graph, and
|
||||
// added subgraph.
|
||||
CheckNode(graph, "bar", "NotImportant", "", {}, {}, {"baz:1", "s/n1"});
|
||||
CheckNode(graph, "s/n1", "NotImportant", "", {}, {"bar", "s/n0"}, {});
|
||||
CheckGraph(graph);
|
||||
}
|
||||
|
||||
// TODO(ezhulenev): Add support for adding a subgraph and merging function
|
||||
// libraries.
|
||||
TEST(MutableGraphViewTest, AddSubgraphWithFunctionLibrary) {
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{
|
||||
NDef("foo", "NotImportant", {}, {}),
|
||||
NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("baz", "NotImportant", {"foo", "bar"}),
|
||||
},
|
||||
/*funcs=*/{});
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
FunctionDef x_times_two = test::function::XTimesTwo();
|
||||
GraphDef subgraph = test::function::GDef(
|
||||
{
|
||||
NDef("s/n0", "NotImportant", {}, {}),
|
||||
NDef("s/n1", "NotImportant", {"bar", "s/n0"}, {}),
|
||||
},
|
||||
/*funcs=*/{x_times_two});
|
||||
|
||||
Status status = graph.AddSubgraph(std::move(subgraph));
|
||||
EXPECT_FALSE(status.ok());
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
|
Loading…
x
Reference in New Issue
Block a user