[Grappler] MutableGraphView: Update fanouts.
Current implementation if ReplaceInputs is half done/broken. Add UpdateFanouts functions that properly takes care of control dependencies, and updates internal state. PiperOrigin-RevId: 219884693
This commit is contained in:
parent
2c923299cc
commit
8a91e6adc6
@ -170,8 +170,10 @@ cc_library(
|
||||
":graph_view",
|
||||
":grappler_item",
|
||||
":utils",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -184,6 +186,7 @@ tf_cc_test(
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
@ -134,7 +134,7 @@ class GraphViewInternal {
|
||||
// of an output (resp. input) port.
|
||||
const absl::flat_hash_set<InputPort>& GetFanout(
|
||||
const OutputPort& port) const {
|
||||
return gtl::FindWithDefault(fanouts_, port, empty_set_);
|
||||
return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_);
|
||||
}
|
||||
|
||||
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
|
||||
@ -173,7 +173,7 @@ class GraphViewInternal {
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlled_nodes ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||
gtl::FindWithDefault(max_regular_output_port_, port.node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
@ -220,7 +220,7 @@ class GraphViewInternal {
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlling_nodes ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||
gtl::FindWithDefault(max_regular_output_port_, port.node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
@ -241,7 +241,7 @@ class GraphViewInternal {
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlled_edges ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
|
||||
gtl::FindWithDefault(max_regular_output_port_, &node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
@ -290,29 +290,42 @@ class GraphViewInternal {
|
||||
if (output.port_id < 0) {
|
||||
fanouts_[output].emplace(node, -1);
|
||||
} else {
|
||||
num_regular_outputs_[output.node] =
|
||||
std::max(num_regular_outputs_[output.node], output.port_id);
|
||||
max_regular_output_port_[output.node] =
|
||||
std::max(max_regular_output_port_[output.node], output.port_id);
|
||||
fanouts_[output].emplace(node, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Access to the mutable internal state for MutableGraphView.
|
||||
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
|
||||
return &nodes_;
|
||||
absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; }
|
||||
|
||||
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() {
|
||||
return fanouts_;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
|
||||
mutable_fanouts() {
|
||||
return &fanouts_;
|
||||
absl::flat_hash_map<const NodeDef*, int>& max_regular_output_port() {
|
||||
return max_regular_output_port_;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphDefT* graph_; // must outlive the graph view
|
||||
|
||||
// A mapping from the node name to the node itself.
|
||||
absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
|
||||
absl::flat_hash_set<InputPort> empty_set_;
|
||||
|
||||
// A mapping from the output port to all inputs that read from it.
|
||||
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
|
||||
std::unordered_map<NodeDefT*, int> num_regular_outputs_;
|
||||
|
||||
// Keep a maximum index of tensor fetched from the node. It doesn't guarantee
|
||||
// that all tensors in the [0, max_regular_output_port] range are actually
|
||||
// fetched by other nodes.
|
||||
absl::flat_hash_map<const NodeDef*, int> max_regular_output_port_;
|
||||
|
||||
// If the node has no fanouts at given output port (output tensor consumers)
|
||||
// we return a reference to this set from `GetFanout` (we can't construct new
|
||||
// empty set every time, because we need a non-dangling reference).
|
||||
absl::flat_hash_set<InputPort> fanout_not_found_value_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
@ -14,6 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -47,52 +50,137 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
||||
return node_in_graph;
|
||||
}
|
||||
|
||||
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
||||
const int output_port_id) {
|
||||
auto* node_in_graph = graph()->add_node();
|
||||
*node_in_graph = std::move(node);
|
||||
|
||||
AddUniqueNodeOrDie(node_in_graph);
|
||||
|
||||
// replace input for the output nodes of `input_node` with `node`
|
||||
ReplaceInput(input_node, *node_in_graph, output_port_id);
|
||||
|
||||
AddFanouts(node_in_graph);
|
||||
return node_in_graph;
|
||||
void MutableGraphView::UpdateFanouts(absl::string_view from_node,
|
||||
absl::string_view to_node) {
|
||||
NodeDef* from_node_ptr = GetNode(from_node);
|
||||
NodeDef* to_node_ptr = GetNode(to_node);
|
||||
if (from_node_ptr && to_node_ptr) {
|
||||
UpdateFanouts(from_node_ptr, to_node_ptr);
|
||||
} else if (!from_node_ptr) {
|
||||
LOG(WARNING) << absl::Substitute(
|
||||
"Can't update fanouts from '$0' to '$1', from node was not found.",
|
||||
from_node, to_node);
|
||||
} else {
|
||||
LOG(WARNING) << absl::Substitute(
|
||||
"Can't update fanouts from '$0' to '$1', to node was not found.",
|
||||
from_node, to_node);
|
||||
}
|
||||
}
|
||||
|
||||
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
||||
const NodeDef& new_input,
|
||||
const int output_port_id) {
|
||||
OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
|
||||
auto fanout = GetFanout(output_port);
|
||||
for (auto& input_port : fanout) {
|
||||
input_port.node->set_input(input_port.port_id, new_input.name());
|
||||
AddFanouts(input_port.node);
|
||||
void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
|
||||
VLOG(0) << absl::Substitute("Update fanouts from '$0' to '$1'.",
|
||||
from_node->name(), to_node->name());
|
||||
|
||||
// Update internal state with the new output_port->input_port edge.
|
||||
const auto add_edge = [this](const OutputPort& output_port,
|
||||
const InputPort& input_port) {
|
||||
fanouts()[output_port].insert(input_port);
|
||||
};
|
||||
|
||||
// Remove invalidated edge from the internal state.
|
||||
const auto remove_edge = [this](const OutputPort& output_port,
|
||||
const InputPort& input_port) {
|
||||
fanouts()[output_port].erase(input_port);
|
||||
};
|
||||
|
||||
// First we update regular fanouts. For the regular fanouts
|
||||
// `input_port:port_id` is the input index in NodeDef.
|
||||
|
||||
auto regular_edges =
|
||||
GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
|
||||
|
||||
// Maximum index of the `from_node` output tensor that is still used as an
|
||||
// input to some other node.
|
||||
int keep_max_regular_output_port = -1;
|
||||
|
||||
for (const Edge& edge : regular_edges) {
|
||||
const OutputPort output_port = edge.src;
|
||||
const InputPort input_port = edge.dst;
|
||||
|
||||
// If the `to_node` reads from the `from_node`, skip this edge (see
|
||||
// AddAndUpdateFanoutsWithoutSelfLoops test for an example).
|
||||
if (input_port.node == to_node) {
|
||||
keep_max_regular_output_port =
|
||||
std::max(keep_max_regular_output_port, input_port.port_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update input at destination node.
|
||||
input_port.node->set_input(
|
||||
input_port.port_id,
|
||||
output_port.port_id == 0
|
||||
? to_node->name()
|
||||
: absl::StrCat(to_node->name(), ":", output_port.port_id));
|
||||
|
||||
// Remove old edge between the `from_node` and the fanout node.
|
||||
remove_edge(output_port, input_port);
|
||||
// Add an edge between the `to_node` and new fanout node.
|
||||
add_edge(OutputPort(to_node, output_port.port_id), input_port);
|
||||
}
|
||||
|
||||
// For the control fanouts we do not know the input index in a NodeDef,
|
||||
// so we have to traverse all control inputs.
|
||||
|
||||
auto control_fanouts =
|
||||
GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
|
||||
if (control_fanouts.empty()) return;
|
||||
|
||||
const string from_control_input = absl::StrCat("^", from_node->name());
|
||||
const string to_control_input = absl::StrCat("^", to_node->name());
|
||||
|
||||
for (const InputPort& control_port : control_fanouts) {
|
||||
// Node can't be control dependency of itself.
|
||||
if (control_port.node == to_node) continue;
|
||||
|
||||
// Find and update input corresponding to control dependency.
|
||||
NodeDef* node = control_port.node;
|
||||
for (int i = node->input_size() - 1; i >= 0; --i) {
|
||||
const string& input = node->input(i);
|
||||
if (!IsControlInput(input)) break; // we reached regular inputs
|
||||
if (input == from_control_input) {
|
||||
node->set_input(i, to_control_input);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove old edge between the `from_node` and the fanout node.
|
||||
remove_edge(OutputPort(from_node, Graph::kControlSlot), control_port);
|
||||
// Add an edge between the `to_node` and new fanout node.
|
||||
add_edge(OutputPort(to_node, Graph::kControlSlot), control_port);
|
||||
}
|
||||
|
||||
// Because we update all regular fanouts of `from_node`, we can just copy
|
||||
// the value `num_regular_outputs`.
|
||||
max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
|
||||
|
||||
// Check if all fanouts were updated to read from the `to_node`.
|
||||
if (keep_max_regular_output_port >= 0) {
|
||||
max_regular_output_port()[from_node] = keep_max_regular_output_port;
|
||||
} else {
|
||||
max_regular_output_port().erase(from_node);
|
||||
}
|
||||
}
|
||||
|
||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
|
||||
RemoveFanouts(nodes().at(node_name_to_delete));
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
mutable_nodes()->erase(node_name_to_delete);
|
||||
nodes().erase(node_name_to_delete);
|
||||
EraseNodesFromGraph(nodes_to_delete, graph());
|
||||
}
|
||||
|
||||
void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node->input(i));
|
||||
OutputPort fanin((*mutable_nodes())[tensor_id.node()], tensor_id.index());
|
||||
void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) {
|
||||
for (int i = 0; i < deleted_node->input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(deleted_node->input(i));
|
||||
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
|
||||
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
input.node = deleted_node;
|
||||
if (tensor_id.index() < 0)
|
||||
input.port_id = -1;
|
||||
input.port_id = Graph::kControlSlot;
|
||||
else
|
||||
input.port_id = i;
|
||||
|
||||
(*mutable_fanouts())[fanin].erase(input);
|
||||
fanouts()[fanin].erase(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -44,31 +44,44 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
using GraphViewInternal::GetRegularFanin;
|
||||
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
|
||||
|
||||
// Adds a new node to graph and updates the view.
|
||||
// Adds a new node to graph and updates the view. Returns a pointer to the
|
||||
// node in graph.
|
||||
NodeDef* AddNode(NodeDef&& node);
|
||||
|
||||
// Inserts a new node to the graph after `input` node and updates the view.
|
||||
// This adds `node` to the graph and replaces the input for the output
|
||||
// nodes of `input` with a port `output_port_id` with the new node.
|
||||
NodeDef* InsertNode(const NodeDef& input, NodeDef&& node,
|
||||
int output_port_id = 0);
|
||||
|
||||
// Replaces the input for the output nodes of 'old_input' with a port
|
||||
// `output_port_id` with 'new_input'.
|
||||
// Updates all fanouts (input ports fetching output tensors) from `from_node`
|
||||
// to the `to_node`, including control dependencies.
|
||||
//
|
||||
// E.g: We have 2 nodes that use 'bar' node outputs as inputs:
|
||||
// foo(bar:0, bar:1), foo2(other:0, bar:0)
|
||||
// Calling ReplaceInput(bar, new, 0) changes every occurrence of bar:0 for
|
||||
// new:0. Result:
|
||||
// foo(new:0, bar:1), foo2(other:0, new:0)
|
||||
void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
|
||||
int output_port_id = 0);
|
||||
// Example: We have 2 nodes that use `bar` node output tensors as inputs:
|
||||
// 1. foo1(bar:0, bar:1, other:0, ^bar)
|
||||
// 2. foo2(bar:1, other:1)
|
||||
//
|
||||
// After calling ForwardOutputs(bar, new_bar):
|
||||
// 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar)
|
||||
// 2. foo2(new_bar:1, other:1)
|
||||
void UpdateFanouts(absl::string_view from_node, absl::string_view to_node);
|
||||
|
||||
// Deletes nodes from the graph.
|
||||
void DeleteNodes(const std::set<string>& nodes_to_delete);
|
||||
|
||||
private:
|
||||
void RemoveFanouts(NodeDef* node);
|
||||
// Updates all fanouts (input ports fetching output tensors) from `from_node`
|
||||
// to the `to_node`, including control dependencies.
|
||||
//
|
||||
// Example: We have 2 nodes that use `bar` node output tensors as inputs:
|
||||
// 1. foo1(bar:0, bar:1, other:0, ^bar)
|
||||
// 2. foo2(bar:1, other:1)
|
||||
//
|
||||
// After calling ForwardOutputs(bar, new_bar):
|
||||
// 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar)
|
||||
// 2. foo2(new_bar:1, other:1)
|
||||
//
|
||||
// IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the
|
||||
// behavior is undefined.
|
||||
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
|
||||
|
||||
// Remove fanouts of the deleted node from internal state (including control
|
||||
// dependencies).
|
||||
void RemoveFanouts(NodeDef* deleted_node);
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
@ -23,104 +24,122 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
bool FindChildWithName(const MutableGraphView& graph,
|
||||
const string& output_port_name,
|
||||
const string& input_name) {
|
||||
MutableGraphView::OutputPort output_port =
|
||||
graph.GetOutputPort(output_port_name, 0);
|
||||
auto fanout = graph.GetFanout(output_port);
|
||||
for (auto& input_port : fanout) {
|
||||
if (input_port.node->name() == input_name) return true;
|
||||
}
|
||||
return false;
|
||||
using ::tensorflow::test::function::NDef;
|
||||
|
||||
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("other", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
|
||||
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NotImportant", {}, {}));
|
||||
NodeDef* bar = graph.GetNode("bar");
|
||||
|
||||
graph.UpdateFanouts(bar->name(), new_bar->name());
|
||||
|
||||
// Fanout nodes must have their inputs updated.
|
||||
NodeDef* foo_1 = graph.GetNode("foo_1");
|
||||
ASSERT_NE(foo_1, nullptr);
|
||||
ASSERT_EQ(foo_1->input_size(), 4);
|
||||
EXPECT_EQ(foo_1->input(0), "new_bar");
|
||||
EXPECT_EQ(foo_1->input(1), "other");
|
||||
EXPECT_EQ(foo_1->input(2), "new_bar:1");
|
||||
EXPECT_EQ(foo_1->input(3), "^new_bar");
|
||||
|
||||
NodeDef* foo_2 = graph.GetNode("foo_2");
|
||||
ASSERT_NE(foo_2, nullptr);
|
||||
ASSERT_EQ(foo_2->input_size(), 3);
|
||||
EXPECT_EQ(foo_2->input(0), "other:1");
|
||||
EXPECT_EQ(foo_2->input(1), "new_bar:2");
|
||||
EXPECT_EQ(foo_2->input(2), "^new_bar");
|
||||
|
||||
// And fanouts mapping must be also updated for both nodes.
|
||||
bool include_control_fanouts = true;
|
||||
auto old_node_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
|
||||
auto new_node_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts);
|
||||
|
||||
EXPECT_TRUE(old_node_fanouts.empty());
|
||||
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 0)), 1);
|
||||
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 2)), 1);
|
||||
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, -1)), 1);
|
||||
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1);
|
||||
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1);
|
||||
}
|
||||
|
||||
TrivialTestGraphInputYielder SimpleGraph() {
|
||||
// This outputs simple graph like:
|
||||
// x
|
||||
// / \
|
||||
// Square Square_1
|
||||
// | \ / |
|
||||
// | \/ |
|
||||
// | /\ |
|
||||
// | / \ |
|
||||
// AddN AddN_1
|
||||
// \ /
|
||||
// y
|
||||
TrivialTestGraphInputYielder simple_graph(2, 2, 2, false,
|
||||
{"/CPU:0", "/GPU:0"});
|
||||
return simple_graph;
|
||||
}
|
||||
TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def =
|
||||
test::function::GDef({NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("foo", "NotImportant", {"bar", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
|
||||
TEST(MutableGraphViewTest, AddAndReplaceInput) {
|
||||
TrivialTestGraphInputYielder fake_input = SimpleGraph();
|
||||
GrapplerItem item;
|
||||
CHECK(fake_input.NextItem(&item));
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
GraphDef new_graph = item.graph;
|
||||
MutableGraphView graph(&new_graph);
|
||||
// `new_bar` reads the output of an original `bar` node.
|
||||
NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NewBar", {"bar"}, {}));
|
||||
NodeDef* bar = graph.GetNode("bar");
|
||||
|
||||
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
EXPECT_EQ("AddN", input.node->name());
|
||||
EXPECT_EQ(0, input.port_id);
|
||||
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||
EXPECT_EQ("Square", fanin.node->name());
|
||||
EXPECT_EQ(0, fanin.port_id);
|
||||
graph.UpdateFanouts("bar", new_bar->name());
|
||||
|
||||
EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node"));
|
||||
// Foo node must read from `new_bar`.
|
||||
NodeDef* foo = graph.GetNode("foo");
|
||||
ASSERT_NE(foo, nullptr);
|
||||
ASSERT_EQ(foo->input_size(), 2);
|
||||
EXPECT_EQ(foo->input(0), "new_bar");
|
||||
EXPECT_EQ(foo->input(1), "^new_bar");
|
||||
|
||||
NodeDef new_node = *input.node;
|
||||
new_node.set_name("new_node");
|
||||
// And the `new_bar` should read from the original `bar`.
|
||||
ASSERT_EQ(new_bar->input_size(), 1);
|
||||
ASSERT_EQ(new_bar->input(0), "bar");
|
||||
|
||||
EXPECT_EQ(graph.GetNode("new_node"), nullptr);
|
||||
NodeDef* node_in_graph = graph.AddNode(std::move(new_node));
|
||||
EXPECT_NE(graph.GetNode("new_node"), nullptr);
|
||||
// And fanouts mapping must be also updated for both nodes.
|
||||
bool include_control_fanouts = true;
|
||||
auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
|
||||
auto new_bar_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts);
|
||||
|
||||
graph.ReplaceInput(*input.node, *node_in_graph);
|
||||
EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
|
||||
}
|
||||
EXPECT_EQ(bar_fanouts.size(), 1);
|
||||
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(new_bar, 0)), 1);
|
||||
|
||||
TEST(MutableGraphViewTest, InsertNodes) {
|
||||
TrivialTestGraphInputYielder fake_input = SimpleGraph();
|
||||
|
||||
GrapplerItem item;
|
||||
CHECK(fake_input.NextItem(&item));
|
||||
|
||||
GraphDef new_graph = item.graph;
|
||||
MutableGraphView graph(&new_graph);
|
||||
|
||||
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
|
||||
NodeDef new_node = *input.node;
|
||||
new_node.set_name("new_node");
|
||||
new_node.set_input(0, input.node->name());
|
||||
|
||||
EXPECT_EQ(graph.GetNode("new_node"), nullptr);
|
||||
graph.InsertNode(*input.node, std::move(new_node));
|
||||
EXPECT_NE(graph.GetNode("new_node"), nullptr);
|
||||
EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y"));
|
||||
EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
|
||||
EXPECT_EQ(new_bar_fanouts.size(), 2);
|
||||
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, 0)), 1);
|
||||
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, DeleteNodes) {
|
||||
// Outputs simple graph as described in first test.
|
||||
TrivialTestGraphInputYielder fake_input = SimpleGraph();
|
||||
GrapplerItem item;
|
||||
CHECK(fake_input.NextItem(&item));
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("other", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
|
||||
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
|
||||
GraphDef new_graph = item.graph;
|
||||
MutableGraphView graph(&new_graph);
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
EXPECT_NE(graph.GetNode("AddN"), nullptr);
|
||||
graph.DeleteNodes({"AddN"});
|
||||
EXPECT_NE(graph.GetNode("foo_1"), nullptr);
|
||||
graph.DeleteNodes({"foo_1"});
|
||||
|
||||
EXPECT_EQ(graph.GetNode("AddN"), nullptr);
|
||||
EXPECT_EQ(graph.GetNode("foo_1"), nullptr);
|
||||
|
||||
NodeDef* bar = graph.GetNode("bar");
|
||||
NodeDef* other = graph.GetNode("other");
|
||||
NodeDef* foo_2 = graph.GetNode("foo_2");
|
||||
|
||||
bool include_control_fanouts = true;
|
||||
auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
|
||||
auto other_fanouts = graph.GetFanouts(*other, include_control_fanouts);
|
||||
|
||||
EXPECT_EQ(bar_fanouts.size(), 2);
|
||||
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1);
|
||||
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1);
|
||||
|
||||
EXPECT_EQ(other_fanouts.size(), 1);
|
||||
EXPECT_EQ(other_fanouts.count(MutableGraphView::InputPort(foo_2, 0)), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -109,7 +109,7 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode(
|
||||
*first_filter_node, *second_filter_node, *fused_predicate, &graph));
|
||||
|
||||
graph.ReplaceInput(*second_filter_node, *fused_filter_node);
|
||||
graph.UpdateFanouts(second_filter_node->name(), fused_filter_node->name());
|
||||
|
||||
// TODO(prazek): we should run some optimizations on the fused filter
|
||||
// functions, or make sure that optimization passes run after filter
|
||||
|
@ -266,7 +266,7 @@ Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
const auto* stateless_map = graph.AddNode(
|
||||
MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
|
||||
|
||||
graph.ReplaceInput(*map_node, *stateless_map);
|
||||
graph.UpdateFanouts(map_node->name(), stateless_map->name());
|
||||
|
||||
// TODO(b/116285210): we could also remove map functions from library if
|
||||
// they are not used anymore.
|
||||
|
@ -96,7 +96,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
}
|
||||
|
||||
graph.InsertNode(node, MakeLatencyNode(node, &graph));
|
||||
NodeDef* latency_node = graph.AddNode(MakeLatencyNode(node, &graph));
|
||||
graph.UpdateFanouts(node.name(), latency_node->name());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ Status MakeNumaAware::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
if (node.op() != "MapAndBatchDatasetV2") continue;
|
||||
|
||||
auto* numa_node = graph.AddNode(MakeNumaAwareNode(node, &graph));
|
||||
graph.ReplaceInput(node, *numa_node);
|
||||
graph.UpdateFanouts(node.name(), numa_node->name());
|
||||
nodes_to_delete.insert(node.name());
|
||||
}
|
||||
graph.DeleteNodes(nodes_to_delete);
|
||||
|
@ -113,7 +113,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
auto* new_node =
|
||||
graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
|
||||
graph.ReplaceInput(batch_node, *new_node);
|
||||
graph.UpdateFanouts(batch_node.name(), new_node->name());
|
||||
|
||||
// Mark the `Map` and `Batch` nodes for removal.
|
||||
nodes_to_delete.insert(map_node->name());
|
||||
|
@ -145,7 +145,7 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
const auto* filter_by_component = graph.AddNode(
|
||||
MakeFilterByLastComponentNode(*fused_maps, *filter_node, &graph));
|
||||
|
||||
graph.ReplaceInput(*filter_node, *filter_by_component);
|
||||
graph.UpdateFanouts(filter_node->name(), filter_by_component->name());
|
||||
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
|
||||
|
||||
// TODO(prazek): we could also remove functions from library if they are not
|
||||
|
@ -123,7 +123,7 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
const auto* fused_maps_node = graph.AddNode(
|
||||
MakeFusedNode(*parent_map_node, *map_node, *fused_function, &graph));
|
||||
|
||||
graph.ReplaceInput(*map_node, *fused_maps_node);
|
||||
graph.UpdateFanouts(map_node->name(), fused_maps_node->name());
|
||||
|
||||
// TODO(prazek): we should run some optimizations on the fused map
|
||||
// functions, or make sure that optimization passes run after map
|
||||
|
@ -83,7 +83,7 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
if (!CanParallelize(*function, function_library)) continue;
|
||||
|
||||
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
|
||||
graph.ReplaceInput(*map_node, *parallel_map);
|
||||
graph.UpdateFanouts(map_node->name(), parallel_map->name());
|
||||
nodes_to_delete.insert(map_node->name());
|
||||
}
|
||||
|
||||
|
@ -264,7 +264,7 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
auto* new_map_node = graph.AddNode(MakeNewMapNode(
|
||||
*map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
|
||||
graph.ReplaceInput(batch_node, *new_map_node);
|
||||
graph.UpdateFanouts(batch_node.name(), new_map_node->name());
|
||||
|
||||
// Mark the `Map` and `Batch` nodes for removal.
|
||||
nodes_to_delete.insert(map_node->name());
|
||||
|
@ -79,7 +79,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
if (!IsNoOp(node, graph)) continue;
|
||||
|
||||
NodeDef* const parent = graph_utils::GetInputNode(node, graph);
|
||||
graph.ReplaceInput(node, *parent);
|
||||
graph.UpdateFanouts(node.name(), parent->name());
|
||||
|
||||
nodes_to_delete.insert(node.name());
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
|
||||
|
||||
NodeDef* shuffle_and_repeat_node =
|
||||
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
|
||||
graph.ReplaceInput(repeat_node, *shuffle_and_repeat_node);
|
||||
graph.UpdateFanouts(repeat_node.name(), shuffle_and_repeat_node->name());
|
||||
|
||||
// Mark the `Shuffle` and `Repeat` nodes for removal.
|
||||
nodes_to_delete.insert(shuffle_node.name());
|
||||
|
Loading…
Reference in New Issue
Block a user