[Grappler] MutableGraphView: Update fanouts.

The current implementation of ReplaceInput is half done/broken. Add an UpdateFanouts function that properly takes care of control dependencies and updates internal state.

PiperOrigin-RevId: 219884693
This commit is contained in:
Eugene Zhulenev 2018-11-02 16:56:31 -07:00 committed by TensorFlower Gardener
parent 2c923299cc
commit 8a91e6adc6
16 changed files with 289 additions and 152 deletions

View File

@ -170,8 +170,10 @@ cc_library(
":graph_view",
":grappler_item",
":utils",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)
@ -184,6 +186,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)

View File

@ -134,7 +134,7 @@ class GraphViewInternal {
// of an output (resp. input) port.
const absl::flat_hash_set<InputPort>& GetFanout(
const OutputPort& port) const {
return gtl::FindWithDefault(fanouts_, port, empty_set_);
return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_);
}
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
@ -173,7 +173,7 @@ class GraphViewInternal {
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
gtl::FindWithDefault(max_regular_output_port_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
@ -220,7 +220,7 @@ class GraphViewInternal {
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlling_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
gtl::FindWithDefault(max_regular_output_port_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
@ -241,7 +241,7 @@ class GraphViewInternal {
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_edges ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
gtl::FindWithDefault(max_regular_output_port_, &node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
@ -290,29 +290,42 @@ class GraphViewInternal {
if (output.port_id < 0) {
fanouts_[output].emplace(node, -1);
} else {
num_regular_outputs_[output.node] =
std::max(num_regular_outputs_[output.node], output.port_id);
max_regular_output_port_[output.node] =
std::max(max_regular_output_port_[output.node], output.port_id);
fanouts_[output].emplace(node, i);
}
}
}
// Access to the mutable internal state for MutableGraphView.
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
return &nodes_;
absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; }
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() {
return fanouts_;
}
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
mutable_fanouts() {
return &fanouts_;
absl::flat_hash_map<const NodeDef*, int>& max_regular_output_port() {
return max_regular_output_port_;
}
private:
GraphDefT* graph_; // must outlive the graph view
// A mapping from the node name to the node itself.
absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
absl::flat_hash_set<InputPort> empty_set_;
// A mapping from the output port to all inputs that read from it.
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
std::unordered_map<NodeDefT*, int> num_regular_outputs_;
// Keep a maximum index of tensor fetched from the node. It doesn't guarantee
// that all tensors in the [0, max_regular_output_port] range are actually
// fetched by other nodes.
absl::flat_hash_map<const NodeDef*, int> max_regular_output_port_;
// If the node has no fanouts at given output port (output tensor consumers)
// we return a reference to this set from `GetFanout` (we can't construct new
// empty set every time, because we need a non-dangling reference).
absl::flat_hash_set<InputPort> fanout_not_found_value_;
};
} // namespace internal

View File

@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
@ -47,52 +50,137 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
return node_in_graph;
}
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
const int output_port_id) {
auto* node_in_graph = graph()->add_node();
*node_in_graph = std::move(node);
AddUniqueNodeOrDie(node_in_graph);
// replace input for the output nodes of `input_node` with `node`
ReplaceInput(input_node, *node_in_graph, output_port_id);
AddFanouts(node_in_graph);
return node_in_graph;
// Name-based entry point: resolves both node names in the view and forwards
// to the pointer-based overload. If either node is missing, nothing is
// changed and a warning is logged (the missing `from` node is reported in
// preference to the missing `to` node when both are absent).
void MutableGraphView::UpdateFanouts(absl::string_view from_node,
                                     absl::string_view to_node) {
  NodeDef* const source = GetNode(from_node);
  NodeDef* const target = GetNode(to_node);

  if (!source) {
    LOG(WARNING) << absl::Substitute(
        "Can't update fanouts from '$0' to '$1', from node was not found.",
        from_node, to_node);
    return;
  }

  if (!target) {
    LOG(WARNING) << absl::Substitute(
        "Can't update fanouts from '$0' to '$1', to node was not found.",
        from_node, to_node);
    return;
  }

  UpdateFanouts(source, target);
}
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
const NodeDef& new_input,
const int output_port_id) {
OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
auto fanout = GetFanout(output_port);
for (auto& input_port : fanout) {
input_port.node->set_input(input_port.port_id, new_input.name());
AddFanouts(input_port.node);
// Redirects the fanouts of `from_node` (regular and control) so that they
// read from `to_node` instead. For every affected consumer this rewrites the
// corresponding NodeDef input string and moves the edge in the view's
// internal `fanouts()` map. Edges whose consumer is `to_node` itself are left
// untouched, so that `to_node` never ends up reading from itself.
//
// Finally the per-node "maximum regular output port" bookkeeping is updated
// for both nodes, regardless of whether `from_node` had control fanouts.
void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
  VLOG(0) << absl::Substitute("Update fanouts from '$0' to '$1'.",
                              from_node->name(), to_node->name());

  // Update internal state with the new output_port->input_port edge.
  const auto add_edge = [this](const OutputPort& output_port,
                               const InputPort& input_port) {
    fanouts()[output_port].insert(input_port);
  };

  // Remove invalidated edge from the internal state.
  const auto remove_edge = [this](const OutputPort& output_port,
                                  const InputPort& input_port) {
    fanouts()[output_port].erase(input_port);
  };

  // First we update regular fanouts. For the regular fanouts
  // `input_port:port_id` is the input index in NodeDef.
  auto regular_edges =
      GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);

  // Maximum index of the `from_node` output tensor that is still used as an
  // input to some other node (i.e. by `to_node`) after the update.
  int keep_max_regular_output_port = -1;
  // Maximum index of the `from_node` output tensor whose consumers were
  // redirected to the same output index of `to_node`.
  int moved_max_regular_output_port = -1;

  for (const Edge& edge : regular_edges) {
    const OutputPort output_port = edge.src;
    const InputPort input_port = edge.dst;

    // If the `to_node` reads from the `from_node`, skip this edge (see
    // AddAndUpdateFanoutsWithoutSelfLoops test for an example). The edge
    // stays attached to `from_node`, so remember which *output* port of
    // `from_node` it uses (not the input index at `to_node`).
    if (input_port.node == to_node) {
      keep_max_regular_output_port =
          std::max(keep_max_regular_output_port, output_port.port_id);
      continue;
    }

    // Update input at destination node. Output port 0 is spelled as the bare
    // node name, other ports as "name:port".
    input_port.node->set_input(
        input_port.port_id,
        output_port.port_id == 0
            ? to_node->name()
            : absl::StrCat(to_node->name(), ":", output_port.port_id));

    // Remove old edge between the `from_node` and the fanout node.
    remove_edge(output_port, input_port);
    // Add an edge between the `to_node` and new fanout node.
    add_edge(OutputPort(to_node, output_port.port_id), input_port);

    moved_max_regular_output_port =
        std::max(moved_max_regular_output_port, output_port.port_id);
  }

  // For the control fanouts we do not know the input index in a NodeDef,
  // so we have to traverse all control inputs.
  //
  // NOTE: `auto` deliberately makes a copy of the set here; `remove_edge`
  // below erases from the very set that `GetFanout` refers to.
  auto control_fanouts =
      GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));

  if (!control_fanouts.empty()) {
    const string from_control_input = absl::StrCat("^", from_node->name());
    const string to_control_input = absl::StrCat("^", to_node->name());

    for (const InputPort& control_port : control_fanouts) {
      // Node can't be control dependency of itself.
      if (control_port.node == to_node) continue;

      // Find and update input corresponding to control dependency. Scan from
      // the back and stop at the first regular input.
      NodeDef* node = control_port.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;  // we reached regular inputs
        if (input == from_control_input) {
          node->set_input(i, to_control_input);
        }
      }

      // Remove old edge between the `from_node` and the fanout node.
      remove_edge(OutputPort(from_node, Graph::kControlSlot), control_port);
      // Add an edge between the `to_node` and new fanout node.
      add_edge(OutputPort(to_node, Graph::kControlSlot), control_port);
    }
  }

  // Regular fanouts were moved to the same output indices of `to_node`, so
  // its maximum regular output port must cover them. Never shrink a larger
  // value that `to_node` may already have, and touch the map through a single
  // reference: flat_hash_map does not guarantee reference stability, so
  // `map[to_node] = map[from_node]` could read through an invalidated
  // reference if the insertion of `to_node` triggers a rehash.
  if (moved_max_regular_output_port >= 0) {
    int& to_node_max_port = max_regular_output_port()[to_node];
    to_node_max_port =
        std::max(to_node_max_port, moved_max_regular_output_port);
  }

  // Check if all fanouts were updated to read from the `to_node`.
  if (keep_max_regular_output_port >= 0) {
    max_regular_output_port()[from_node] = keep_max_regular_output_port;
  } else {
    max_regular_output_port().erase(from_node);
  }
}
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
for (const string& node_name_to_delete : nodes_to_delete)
RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
RemoveFanouts(nodes().at(node_name_to_delete));
for (const string& node_name_to_delete : nodes_to_delete)
mutable_nodes()->erase(node_name_to_delete);
nodes().erase(node_name_to_delete);
EraseNodesFromGraph(nodes_to_delete, graph());
}
void MutableGraphView::RemoveFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node->input(i));
OutputPort fanin((*mutable_nodes())[tensor_id.node()], tensor_id.index());
void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) {
for (int i = 0; i < deleted_node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(deleted_node->input(i));
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
InputPort input;
input.node = node;
input.node = deleted_node;
if (tensor_id.index() < 0)
input.port_id = -1;
input.port_id = Graph::kControlSlot;
else
input.port_id = i;
(*mutable_fanouts())[fanin].erase(input);
fanouts()[fanin].erase(input);
}
}

View File

@ -44,31 +44,44 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
using GraphViewInternal::GetRegularFanin;
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
// Adds a new node to graph and updates the view.
// Adds a new node to graph and updates the view. Returns a pointer to the
// node in graph.
NodeDef* AddNode(NodeDef&& node);
// Inserts a new node to the graph after `input` node and updates the view.
// This adds `node` to the graph and replaces the input for the output
// nodes of `input` with a port `output_port_id` with the new node.
NodeDef* InsertNode(const NodeDef& input, NodeDef&& node,
int output_port_id = 0);
// Replaces the input for the output nodes of 'old_input' with a port
// `output_port_id` with 'new_input'.
// Updates all fanouts (input ports fetching output tensors) from `from_node`
// to the `to_node`, including control dependencies.
//
// E.g: We have 2 nodes that use 'bar' node outputs as inputs:
// foo(bar:0, bar:1), foo2(other:0, bar:0)
// Calling ReplaceInput(bar, new, 0) changes every occurrence of bar:0 for
// new:0. Result:
// foo(new:0, bar:1), foo2(other:0, new:0)
void ReplaceInput(const NodeDef& old_input, const NodeDef& new_input,
int output_port_id = 0);
// Example: We have 2 nodes that use `bar` node output tensors as inputs:
// 1. foo1(bar:0, bar:1, other:0, ^bar)
// 2. foo2(bar:1, other:1)
//
// After calling UpdateFanouts(bar, new_bar):
// 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar)
// 2. foo2(new_bar:1, other:1)
void UpdateFanouts(absl::string_view from_node, absl::string_view to_node);
// Deletes nodes from the graph.
void DeleteNodes(const std::set<string>& nodes_to_delete);
private:
void RemoveFanouts(NodeDef* node);
// Updates all fanouts (input ports fetching output tensors) from `from_node`
// to the `to_node`, including control dependencies.
//
// Example: We have 2 nodes that use `bar` node output tensors as inputs:
// 1. foo1(bar:0, bar:1, other:0, ^bar)
// 2. foo2(bar:1, other:1)
//
// After calling UpdateFanouts(bar, new_bar):
// 1. foo1(new_bar:0, new_bar:1, other:0, ^new_bar)
// 2. foo2(new_bar:1, other:1)
//
// IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the
// behavior is undefined.
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
// Remove fanouts of the deleted node from internal state (including control
// dependencies).
void RemoveFanouts(NodeDef* deleted_node);
};
} // end namespace grappler

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/platform/test.h"
@ -23,104 +24,122 @@ namespace tensorflow {
namespace grappler {
namespace {
bool FindChildWithName(const MutableGraphView& graph,
const string& output_port_name,
const string& input_name) {
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(output_port_name, 0);
auto fanout = graph.GetFanout(output_port);
for (auto& input_port : fanout) {
if (input_port.node->name() == input_name) return true;
}
return false;
using ::tensorflow::test::function::NDef;
// Verifies that UpdateFanouts("bar", "new_bar") rewrites both the regular and
// the control inputs of every consumer of `bar`, and that the view's fanout
// mapping is moved from `bar` to `new_bar`.
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("bar", "NotImportant", {}, {}),
NDef("other", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
/* empty function library */ {});
MutableGraphView graph(&graph_def);
// `new_bar` has no inputs of its own, so it does not read from `bar`.
NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NotImportant", {}, {}));
NodeDef* bar = graph.GetNode("bar");
graph.UpdateFanouts(bar->name(), new_bar->name());
// Fanout nodes must have their inputs updated.
NodeDef* foo_1 = graph.GetNode("foo_1");
ASSERT_NE(foo_1, nullptr);
ASSERT_EQ(foo_1->input_size(), 4);
EXPECT_EQ(foo_1->input(0), "new_bar");
// Inputs that did not read from `bar` must be left untouched.
EXPECT_EQ(foo_1->input(1), "other");
EXPECT_EQ(foo_1->input(2), "new_bar:1");
EXPECT_EQ(foo_1->input(3), "^new_bar");
NodeDef* foo_2 = graph.GetNode("foo_2");
ASSERT_NE(foo_2, nullptr);
ASSERT_EQ(foo_2->input_size(), 3);
EXPECT_EQ(foo_2->input(0), "other:1");
EXPECT_EQ(foo_2->input(1), "new_bar:2");
EXPECT_EQ(foo_2->input(2), "^new_bar");
// And fanouts mapping must be also updated for both nodes.
bool include_control_fanouts = true;
auto old_node_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
auto new_node_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts);
// Every consumer of `bar` was redirected, so `bar` has no fanouts left.
EXPECT_TRUE(old_node_fanouts.empty());
// Regular fanouts are keyed by the consumer's input index; control fanouts
// use port id -1.
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 0)), 1);
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, 2)), 1);
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_1, -1)), 1);
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1);
EXPECT_EQ(new_node_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1);
}
TrivialTestGraphInputYielder SimpleGraph() {
// This outputs simple graph like:
// x
// / \
// Square Square_1
// | \ / |
// | \/ |
// | /\ |
// | / \ |
// AddN AddN_1
// \ /
// y
TrivialTestGraphInputYielder simple_graph(2, 2, 2, false,
{"/CPU:0", "/GPU:0"});
return simple_graph;
}
TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
// Actual node.op() is not important in this test.
GraphDef graph_def =
test::function::GDef({NDef("bar", "NotImportant", {}, {}),
NDef("foo", "NotImportant", {"bar", "^bar"})},
/* empty function library */ {});
TEST(MutableGraphViewTest, AddAndReplaceInput) {
TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
MutableGraphView graph(&graph_def);
GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph);
// `new_bar` reads the output of an original `bar` node.
NodeDef* new_bar = graph.AddNode(NDef("new_bar", "NewBar", {"bar"}, {}));
NodeDef* bar = graph.GetNode("bar");
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
EXPECT_EQ("AddN", input.node->name());
EXPECT_EQ(0, input.port_id);
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
graph.UpdateFanouts("bar", new_bar->name());
EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node"));
// Foo node must read from `new_bar`.
NodeDef* foo = graph.GetNode("foo");
ASSERT_NE(foo, nullptr);
ASSERT_EQ(foo->input_size(), 2);
EXPECT_EQ(foo->input(0), "new_bar");
EXPECT_EQ(foo->input(1), "^new_bar");
NodeDef new_node = *input.node;
new_node.set_name("new_node");
// And the `new_bar` should read from the original `bar`.
ASSERT_EQ(new_bar->input_size(), 1);
ASSERT_EQ(new_bar->input(0), "bar");
EXPECT_EQ(graph.GetNode("new_node"), nullptr);
NodeDef* node_in_graph = graph.AddNode(std::move(new_node));
EXPECT_NE(graph.GetNode("new_node"), nullptr);
// And fanouts mapping must be also updated for both nodes.
bool include_control_fanouts = true;
auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
auto new_bar_fanouts = graph.GetFanouts(*new_bar, include_control_fanouts);
graph.ReplaceInput(*input.node, *node_in_graph);
EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node"));
EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
}
EXPECT_EQ(bar_fanouts.size(), 1);
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(new_bar, 0)), 1);
TEST(MutableGraphViewTest, InsertNodes) {
TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph);
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
NodeDef new_node = *input.node;
new_node.set_name("new_node");
new_node.set_input(0, input.node->name());
EXPECT_EQ(graph.GetNode("new_node"), nullptr);
graph.InsertNode(*input.node, std::move(new_node));
EXPECT_NE(graph.GetNode("new_node"), nullptr);
EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN"));
EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1"));
EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN"));
EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1"));
EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node"));
EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y"));
EXPECT_TRUE(FindChildWithName(graph, "new_node", "y"));
EXPECT_EQ(new_bar_fanouts.size(), 2);
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, 0)), 1);
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1);
}
TEST(MutableGraphViewTest, DeleteNodes) {
// Outputs simple graph as described in first test.
TrivialTestGraphInputYielder fake_input = SimpleGraph();
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("bar", "NotImportant", {}, {}),
NDef("other", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
/* empty function library */ {});
GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph);
MutableGraphView graph(&graph_def);
EXPECT_NE(graph.GetNode("AddN"), nullptr);
graph.DeleteNodes({"AddN"});
EXPECT_NE(graph.GetNode("foo_1"), nullptr);
graph.DeleteNodes({"foo_1"});
EXPECT_EQ(graph.GetNode("AddN"), nullptr);
EXPECT_EQ(graph.GetNode("foo_1"), nullptr);
NodeDef* bar = graph.GetNode("bar");
NodeDef* other = graph.GetNode("other");
NodeDef* foo_2 = graph.GetNode("foo_2");
bool include_control_fanouts = true;
auto bar_fanouts = graph.GetFanouts(*bar, include_control_fanouts);
auto other_fanouts = graph.GetFanouts(*other, include_control_fanouts);
EXPECT_EQ(bar_fanouts.size(), 2);
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, 1)), 1);
EXPECT_EQ(bar_fanouts.count(MutableGraphView::InputPort(foo_2, -1)), 1);
EXPECT_EQ(other_fanouts.size(), 1);
EXPECT_EQ(other_fanouts.count(MutableGraphView::InputPort(foo_2, 0)), 1);
}
} // namespace

View File

@ -109,7 +109,7 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode(
*first_filter_node, *second_filter_node, *fused_predicate, &graph));
graph.ReplaceInput(*second_filter_node, *fused_filter_node);
graph.UpdateFanouts(second_filter_node->name(), fused_filter_node->name());
// TODO(prazek): we should run some optimizations on the fused filter
// functions, or make sure that optimization passes run after filter

View File

@ -266,7 +266,7 @@ Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto* stateless_map = graph.AddNode(
MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
graph.ReplaceInput(*map_node, *stateless_map);
graph.UpdateFanouts(map_node->name(), stateless_map->name());
// TODO(b/116285210): we could also remove map functions from library if
// they are not used anymore.

View File

@ -96,7 +96,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
graph.InsertNode(node, MakeLatencyNode(node, &graph));
NodeDef* latency_node = graph.AddNode(MakeLatencyNode(node, &graph));
graph.UpdateFanouts(node.name(), latency_node->name());
}
return Status::OK();
}

View File

@ -47,7 +47,7 @@ Status MakeNumaAware::Optimize(Cluster* cluster, const GrapplerItem& item,
if (node.op() != "MapAndBatchDatasetV2") continue;
auto* numa_node = graph.AddNode(MakeNumaAwareNode(node, &graph));
graph.ReplaceInput(node, *numa_node);
graph.UpdateFanouts(node.name(), numa_node->name());
nodes_to_delete.insert(node.name());
}
graph.DeleteNodes(nodes_to_delete);

View File

@ -113,7 +113,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* new_node =
graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
graph.ReplaceInput(batch_node, *new_node);
graph.UpdateFanouts(batch_node.name(), new_node->name());
// Mark the `Map` and `Batch` nodes for removal.
nodes_to_delete.insert(map_node->name());

View File

@ -145,7 +145,7 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto* filter_by_component = graph.AddNode(
MakeFilterByLastComponentNode(*fused_maps, *filter_node, &graph));
graph.ReplaceInput(*filter_node, *filter_by_component);
graph.UpdateFanouts(filter_node->name(), filter_by_component->name());
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
// TODO(prazek): we could also remove functions from library if they are not

View File

@ -123,7 +123,7 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto* fused_maps_node = graph.AddNode(
MakeFusedNode(*parent_map_node, *map_node, *fused_function, &graph));
graph.ReplaceInput(*map_node, *fused_maps_node);
graph.UpdateFanouts(map_node->name(), fused_maps_node->name());
// TODO(prazek): we should run some optimizations on the fused map
// functions, or make sure that optimization passes run after map

View File

@ -83,7 +83,7 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!CanParallelize(*function, function_library)) continue;
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
graph.UpdateFanouts(map_node->name(), parallel_map->name());
nodes_to_delete.insert(map_node->name());
}

View File

@ -264,7 +264,7 @@ Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* new_map_node = graph.AddNode(MakeNewMapNode(
*map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
graph.ReplaceInput(batch_node, *new_map_node);
graph.UpdateFanouts(batch_node.name(), new_map_node->name());
// Mark the `Map` and `Batch` nodes for removal.
nodes_to_delete.insert(map_node->name());

View File

@ -79,7 +79,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsNoOp(node, graph)) continue;
NodeDef* const parent = graph_utils::GetInputNode(node, graph);
graph.ReplaceInput(node, *parent);
graph.UpdateFanouts(node.name(), parent->name());
nodes_to_delete.insert(node.name());
}

View File

@ -86,7 +86,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
NodeDef* shuffle_and_repeat_node =
graph.AddNode(make_shuffle_and_repeat_node(shuffle_node, repeat_node));
graph.ReplaceInput(repeat_node, *shuffle_and_repeat_node);
graph.UpdateFanouts(repeat_node.name(), shuffle_and_repeat_node->name());
// Mark the `Shuffle` and `Repeat` nodes for removal.
nodes_to_delete.insert(shuffle_node.name());