[Grappler] Add node fanin mutations in MutableGraphView.
PiperOrigin-RevId: 225474536
This commit is contained in:
parent
3605ae973e
commit
0b0dea8cf1
@ -176,12 +176,14 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_view",
|
||||
":grappler_item",
|
||||
":op_types",
|
||||
":utils",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
@ -191,7 +193,9 @@ tf_cc_test(
|
||||
deps = [
|
||||
":grappler_item",
|
||||
":mutable_graph_view",
|
||||
":utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
@ -14,14 +14,32 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsTensorIdPortValid(const TensorId& tensor_id) {
|
||||
return tensor_id.index() >= Graph::kControlSlot;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
const absl::flat_hash_set<MutableGraphView::InputPort>&
|
||||
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
|
||||
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
|
||||
@ -160,17 +178,201 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
|
||||
}
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
|
||||
NodeDef* fanin_node = GetNode(fanin.node());
|
||||
if (fanin_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int num_non_controlling_fanins =
|
||||
NumFanins(*node, /*include_controlling_nodes=*/false);
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
input.port_id = fanin.index() == Graph::kControlSlot
|
||||
? Graph::kControlSlot
|
||||
: num_non_controlling_fanins;
|
||||
|
||||
OutputPort fanin_port(fanin_node, fanin.index());
|
||||
|
||||
if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) {
|
||||
return false;
|
||||
}
|
||||
node->add_input(TensorIdToString(fanin));
|
||||
if (fanin.index() > Graph::kControlSlot) {
|
||||
int node_input_size = node->input_size() - 1;
|
||||
// If there are control dependencies in node, move newly inserted fanin to
|
||||
// be before such control dependencies.
|
||||
if (num_non_controlling_fanins < node_input_size) {
|
||||
node->mutable_input()->SwapElements(node_input_size,
|
||||
num_non_controlling_fanins);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFanin(absl::string_view node_name,
|
||||
const TensorId& fanin) {
|
||||
if (!IsTensorIdPortValid(fanin)) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return AddFanin(node, fanin);
|
||||
}
|
||||
|
||||
bool MutableGraphView::RemoveFanins(NodeDef* node,
|
||||
absl::Span<const TensorId> fanins) {
|
||||
bool modified = false;
|
||||
auto mutable_inputs = node->mutable_input();
|
||||
int curr_pos = 0;
|
||||
int num_inputs = node->input_size();
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node->input(i));
|
||||
bool remove_fanin =
|
||||
std::find(fanins.begin(), fanins.end(), tensor_id) != fanins.end();
|
||||
bool update_fanin = !remove_fanin && modified;
|
||||
if (remove_fanin || update_fanin) {
|
||||
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
|
||||
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
input.port_id =
|
||||
tensor_id.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
|
||||
|
||||
if (remove_fanin) {
|
||||
fanouts()[fanin].erase(input);
|
||||
} else {
|
||||
// Shift inputs to be retained.
|
||||
if (tensor_id.index() > Graph::kControlSlot) {
|
||||
fanouts()[fanin].erase(input);
|
||||
fanouts()[fanin].insert(InputPort(node, i));
|
||||
}
|
||||
mutable_inputs->SwapElements(i, curr_pos++);
|
||||
}
|
||||
|
||||
modified = true;
|
||||
} else {
|
||||
// Skip inputs to be retained until first modification.
|
||||
curr_pos++;
|
||||
}
|
||||
}
|
||||
if (modified) {
|
||||
mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
bool MutableGraphView::RemoveFanin(absl::string_view node_name,
|
||||
const TensorId& fanin) {
|
||||
if (!IsTensorIdPortValid(fanin)) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return RemoveFanins(node, {fanin});
|
||||
}
|
||||
|
||||
bool MutableGraphView::RemoveAllFanins(absl::string_view node_name,
|
||||
bool keep_controlling_fanins) {
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr || node->input().empty()) {
|
||||
return false;
|
||||
}
|
||||
RemoveFaninsInternal(node, keep_controlling_fanins);
|
||||
if (keep_controlling_fanins) {
|
||||
int num_non_controlling_fanins =
|
||||
NumFanins(*node, /*include_controlling_nodes=*/false);
|
||||
if (num_non_controlling_fanins == 0) {
|
||||
return false;
|
||||
} else if (num_non_controlling_fanins < node->input_size()) {
|
||||
node->mutable_input()->DeleteSubrange(0, num_non_controlling_fanins);
|
||||
} else {
|
||||
node->clear_input();
|
||||
}
|
||||
} else {
|
||||
node->clear_input();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MutableGraphView::UpdateFanin(absl::string_view node_name,
|
||||
const TensorId& from_fanin,
|
||||
const TensorId& to_fanin) {
|
||||
if (from_fanin == to_fanin || !IsTensorIdPortValid(from_fanin) ||
|
||||
!IsTensorIdPortValid(to_fanin)) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool is_from_fanin_control = from_fanin.index() == Graph::kControlSlot;
|
||||
bool is_to_fanin_control = to_fanin.index() == Graph::kControlSlot;
|
||||
// When replacing a non control dependency fanin with a control dependency, or
|
||||
// vice versa, remove and add, so ports can be updated properly in fanout(s).
|
||||
if (is_from_fanin_control || is_to_fanin_control) {
|
||||
bool modified = RemoveFanins(node, {from_fanin});
|
||||
if (!HasFanin(*node, to_fanin)) {
|
||||
modified |= AddFanin(node, to_fanin);
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
// In place mutation, requires no shifting of ports.
|
||||
NodeDef* from_fanin_node = GetNode(from_fanin.node());
|
||||
NodeDef* to_fanin_node = GetNode(to_fanin.node());
|
||||
if (from_fanin_node == nullptr || to_fanin_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
string to_fanin_string = TensorIdToString(to_fanin);
|
||||
int num_inputs = node->input_size();
|
||||
bool modified = false;
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
if (ParseTensorName(node->input(i)) == from_fanin) {
|
||||
OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
|
||||
InputPort old_input;
|
||||
old_input.node = node;
|
||||
old_input.port_id =
|
||||
from_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
|
||||
fanouts()[from_fanin_port].erase(old_input);
|
||||
|
||||
OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
|
||||
InputPort new_input;
|
||||
new_input.node = node;
|
||||
new_input.port_id =
|
||||
to_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
|
||||
fanouts()[to_fanin_port].insert(new_input);
|
||||
|
||||
node->set_input(i, to_fanin_string);
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
|
||||
return modified;
|
||||
}
|
||||
|
||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
RemoveFanouts(nodes().at(node_name_to_delete));
|
||||
RemoveFaninsInternal(nodes().at(node_name_to_delete),
|
||||
/*keep_controlling_fanins=*/false);
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
nodes().erase(node_name_to_delete);
|
||||
EraseNodesFromGraph(nodes_to_delete, graph());
|
||||
}
|
||||
|
||||
void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) {
|
||||
void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
|
||||
bool keep_controlling_fanins) {
|
||||
for (int i = 0; i < deleted_node->input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(deleted_node->input(i));
|
||||
if (keep_controlling_fanins && tensor_id.index() < 0) {
|
||||
break;
|
||||
}
|
||||
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
|
||||
|
||||
InputPort input;
|
||||
|
@ -16,7 +16,17 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -60,6 +70,38 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
// 2. foo2(new_bar:1, other:1)
|
||||
void UpdateFanouts(absl::string_view from_node, absl::string_view to_node);
|
||||
|
||||
// Add fanin to node `node_name`. If the node or fanin do not exist in the
|
||||
// graph, nothing will be modified in the graph. If fanin is a control
|
||||
// dependency, existing control dependencies will be checked first before
|
||||
// adding. Otherwise fanin will be added after existing non control dependency
|
||||
// inputs.
|
||||
//
|
||||
// This will return true iff the node is modified. If a control dependency
|
||||
// already exists, the node will not be modified.
|
||||
bool AddFanin(absl::string_view node_name, const TensorId& fanin);
|
||||
|
||||
// Remove fanin from node `node_name`. If the node or fanin do not exist in
|
||||
// the graph, nothing will be modified in the graph. If there are multiple
|
||||
// inputs that match the fanin, all of them will be removed.
|
||||
//
|
||||
// This will return true iff the node is modified. If no inputs match the
|
||||
// fanin, the node will not be modified.
|
||||
bool RemoveFanin(absl::string_view node_name, const TensorId& fanin);
|
||||
|
||||
// Remove all fanins from node `node_name`. Control dependencies will be
|
||||
// retained if keep_controlling_fanins is true.
|
||||
//
|
||||
// This will return true iff the node is modified.
|
||||
bool RemoveAllFanins(absl::string_view node_name,
|
||||
bool keep_controlling_fanins);
|
||||
|
||||
// Replace all fanins `from_fanin` with `to_fanin` in node `node_name`. If
|
||||
// the fanins or node do not exist, nothing will be modified in the graph.
|
||||
//
|
||||
// This will return true iff the node is modified.
|
||||
bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
|
||||
const TensorId& to_fanin);
|
||||
|
||||
// Deletes nodes from the graph.
|
||||
void DeleteNodes(const std::set<string>& nodes_to_delete);
|
||||
|
||||
@ -79,9 +121,22 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
// behavior is undefined.
|
||||
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
|
||||
|
||||
// Remove fanouts of the deleted node from internal state (including control
|
||||
// dependencies).
|
||||
void RemoveFanouts(NodeDef* deleted_node);
|
||||
// Remove fanins of the deleted node from internal state. Control dependencies
|
||||
// are retained iff keep_controlling_fanins is true.
|
||||
void RemoveFaninsInternal(NodeDef* deleted_node,
|
||||
bool keep_controlling_fanins);
|
||||
|
||||
// Add fanin to node. If the node or fanin do not exist in the graph, nothing
|
||||
// will be modified in the graph. If fanin is a control dependency, existing
|
||||
// control dependencies will be checked first before adding. Otherwise fanin
|
||||
// will be added after existing non control dependency inputs.
|
||||
//
|
||||
// This will return true iff the node is modified. If a control dependency
|
||||
// already exists, the node will not be modified.
|
||||
bool AddFanin(NodeDef* node, const TensorId& fanin);
|
||||
|
||||
// Remove any fanin in node that matches to a fanin in fanins.
|
||||
bool RemoveFanins(NodeDef* node, absl::Span<const TensorId> fanins);
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -16,8 +16,10 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -110,6 +112,356 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
|
||||
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1);
|
||||
}
|
||||
|
||||
GraphDef SimpleMutateFaninGraph() {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
|
||||
NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"a"}),
|
||||
NDef("foo_2", "NotImportant", {"b", "^a", "^c"}),
|
||||
NDef("foo_3", "NotImportant", {"b", "a:1", "a:1"}),
|
||||
NDef("foo_4", "NotImportant", {"a", "b:2", "b:2", "^c", "^d"}),
|
||||
NDef("foo_5", "NotImportant", {}),
|
||||
NDef("foo_6", "NotImportant", {"^a", "^b"})},
|
||||
/*funcs=*/{});
|
||||
return graph_def;
|
||||
}
|
||||
|
||||
void CompareNodeInputs(const MutableGraphView& graph, const NodeDef* expected,
|
||||
NodeDef* actual) {
|
||||
ASSERT_EQ(actual->input_size(), expected->input_size());
|
||||
int port;
|
||||
for (int i = 0; i < actual->input_size(); ++i) {
|
||||
EXPECT_EQ(actual->input(i), expected->input(i));
|
||||
TensorId tensor_id = ParseTensorName(expected->input(i));
|
||||
if (tensor_id.index() == Graph::kControlSlot) {
|
||||
port = Graph::kControlSlot;
|
||||
} else {
|
||||
port = i;
|
||||
}
|
||||
MutableGraphView::InputPort input_port(actual, port);
|
||||
MutableGraphView::OutputPort output_port =
|
||||
graph.GetOutputPort(tensor_id.node(), tensor_id.index());
|
||||
EXPECT_EQ(graph.GetFanin(input_port).contains(output_port), true);
|
||||
EXPECT_EQ(graph.GetFanout(output_port).contains(input_port), true);
|
||||
}
|
||||
}
|
||||
|
||||
void TestAddFanin(absl::string_view node_name, const TensorId& fanin_to_add,
|
||||
bool modified, const NodeDef* expected_node) {
|
||||
GraphDef graph_def = SimpleMutateFaninGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
auto node = graph.GetNode(node_name);
|
||||
if (expected_node == nullptr) {
|
||||
EXPECT_EQ(node, nullptr);
|
||||
} else {
|
||||
EXPECT_NE(node, nullptr);
|
||||
}
|
||||
|
||||
EXPECT_EQ(modified, graph.AddFanin(node_name, fanin_to_add));
|
||||
if (expected_node != nullptr) {
|
||||
CompareNodeInputs(graph, expected_node, node);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddFanin) {
|
||||
NodeDef expected_node;
|
||||
// Add input to node with 1 input 0 controls.
|
||||
expected_node = NDef("", "", {"a", "b:1"});
|
||||
TestAddFanin("foo_1", {"b", 1}, /*modified=*/true, &expected_node);
|
||||
// Add input to node with multiple inputs and 0 controls.
|
||||
expected_node = NDef("", "", {"b", "a:1", "a:1", "b:2"});
|
||||
TestAddFanin("foo_3", {"b", 2}, /*modified=*/true, &expected_node);
|
||||
// Add input to node with 1 input multiple controls.
|
||||
expected_node = NDef("", "", {"b", "a", "^c", "^a"});
|
||||
TestAddFanin("foo_2", {"a", 0}, /*modified=*/true, &expected_node);
|
||||
// Add input to node with multiple inputs and controls.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "a:1", "^d", "^c"});
|
||||
TestAddFanin("foo_4", {"a", 1}, /*modified=*/true, &expected_node);
|
||||
// Add input to node with 0 inputs 0 controls.
|
||||
expected_node = NDef("", "", {"a:1"});
|
||||
TestAddFanin("foo_5", {"a", 1}, /*modified=*/true, &expected_node);
|
||||
// Add input to node with 0 inputs multiple controls.
|
||||
expected_node = NDef("", "", {"c:1", "^b", "^a"});
|
||||
TestAddFanin("foo_6", {"c", 1}, /*modified=*/true, &expected_node);
|
||||
|
||||
// Add control to node with 1 input 0 controls.
|
||||
expected_node = NDef("", "", {"a", "^b"});
|
||||
TestAddFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with multiple inputs and 0 controls.
|
||||
expected_node = NDef("", "", {"b", "a:1", "a:1", "^c"});
|
||||
TestAddFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with 1 input multiple controls.
|
||||
expected_node = NDef("", "", {"b", "^a", "^c", "^d"});
|
||||
TestAddFanin("foo_2", {"d", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with multiple input multiple controls.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "^c", "^d", "^a"});
|
||||
TestAddFanin("foo_4", {"a", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with 0 inputs 0 controls.
|
||||
expected_node = NDef("", "", {"^a"});
|
||||
TestAddFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with 0 inputs multiple controls.
|
||||
expected_node = NDef("", "", {"^a", "^b", "^c"});
|
||||
TestAddFanin("foo_6", {"c", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Add control to node with control that already exists.
|
||||
expected_node = NDef("", "", {"b", "^a", "^c"});
|
||||
TestAddFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/false,
|
||||
&expected_node);
|
||||
|
||||
// Add fanin to node where node is missing.
|
||||
TestAddFanin("foo_missing", {"a", 0}, /*modified=*/false, nullptr);
|
||||
// Add fanin to node where fanin is missing.
|
||||
expected_node = NDef("", "", {"a"});
|
||||
TestAddFanin("foo_1", {"bar_missing", 0}, /*modified=*/false, &expected_node);
|
||||
// Add fanin to node where node and fanin are missing.
|
||||
TestAddFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
|
||||
/*expected_node=*/nullptr);
|
||||
}
|
||||
|
||||
void CheckFanout(const MutableGraphView& graph, const TensorId& fanin,
|
||||
absl::string_view node_name) {
|
||||
MutableGraphView::OutputPort output_port =
|
||||
graph.GetOutputPort(fanin.node(), fanin.index());
|
||||
auto fanouts = graph.GetFanout(output_port);
|
||||
for (auto fanout : fanouts) {
|
||||
EXPECT_NE(fanout.node->name(), fanin.node());
|
||||
}
|
||||
}
|
||||
|
||||
void TestRemoveFanin(absl::string_view node_name,
|
||||
const TensorId& fanin_to_remove, bool modified,
|
||||
const NodeDef* expected_node) {
|
||||
GraphDef graph_def = SimpleMutateFaninGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
auto node = graph.GetNode(node_name);
|
||||
if (expected_node == nullptr) {
|
||||
EXPECT_EQ(nullptr, node);
|
||||
} else {
|
||||
EXPECT_NE(nullptr, node);
|
||||
}
|
||||
|
||||
EXPECT_EQ(modified, graph.RemoveFanin(node_name, fanin_to_remove));
|
||||
if (expected_node != nullptr) {
|
||||
CompareNodeInputs(graph, expected_node, node);
|
||||
if (modified) {
|
||||
CheckFanout(graph, fanin_to_remove, node_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, RemoveFanin) {
|
||||
NodeDef expected_node;
|
||||
// Remove input from node with 1 input 0 controls.
|
||||
expected_node = NDef("", "", {});
|
||||
TestRemoveFanin("foo_1", {"a", 0}, /*modified=*/true, &expected_node);
|
||||
// Remove input from node with multiple inputs and 0 controls.
|
||||
expected_node = NDef("", "", {"b"});
|
||||
TestRemoveFanin("foo_3", {"a", 1}, /*modified=*/true, &expected_node);
|
||||
// Remove input from node with 1 input multiple controls.
|
||||
expected_node = NDef("", "", {"^a", "^c"});
|
||||
TestRemoveFanin("foo_2", {"b", 0}, /*modified=*/true, &expected_node);
|
||||
// Remove input from node with multiple inputs and controls.
|
||||
expected_node = NDef("", "", {"a", "^c", "^d"});
|
||||
TestRemoveFanin("foo_4", {"b", 2}, /*modified=*/true, &expected_node);
|
||||
|
||||
// Remove control from node with 1 input multiple controls.
|
||||
expected_node = NDef("", "", {"b", "^c"});
|
||||
TestRemoveFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Remove control from node with multiple input multiple controls.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "^c"});
|
||||
TestRemoveFanin("foo_4", {"d", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Remove control from node with 0 inputs multiple controls.
|
||||
expected_node = NDef("", "", {"^b"});
|
||||
TestRemoveFanin("foo_6", {"a", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
|
||||
// Remove input from node with 0 inputs 0 controls.
|
||||
expected_node = NDef("", "", {});
|
||||
TestRemoveFanin("foo_5", {"a", 1}, /*modified=*/false, &expected_node);
|
||||
// Remove input from node with 0 inputs multiple controls.
|
||||
expected_node = NDef("", "", {"^a", "^b"});
|
||||
TestRemoveFanin("foo_6", {"a", 1}, /*modified=*/false, &expected_node);
|
||||
// Remove control from node with 1 input 0 controls.
|
||||
expected_node = NDef("", "", {"a"});
|
||||
TestRemoveFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/false,
|
||||
&expected_node);
|
||||
// Remove control from node with multiple inputs and 0 controls.
|
||||
expected_node = NDef("", "", {"b", "a:1", "a:1"});
|
||||
TestRemoveFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/false,
|
||||
&expected_node);
|
||||
// Remove control from node with 0 inputs 0 controls.
|
||||
expected_node = NDef("", "", {});
|
||||
TestRemoveFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/false,
|
||||
&expected_node);
|
||||
|
||||
// Remove fanin from node where node is missing.
|
||||
TestRemoveFanin("foo_missing", {"a", 0}, /*modified=*/false,
|
||||
/*expected_node=*/nullptr);
|
||||
// Remove fanin from node where fanin is missing.
|
||||
expected_node = NDef("", "", {"a"});
|
||||
TestRemoveFanin("foo_1", {"bar_missing", 0}, /*modified=*/false,
|
||||
&expected_node);
|
||||
// Remove fanin from node where node and fanin are missing.
|
||||
TestRemoveFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
|
||||
/*expected_node=*/nullptr);
|
||||
}
|
||||
|
||||
void TestRemoveAllFanins(absl::string_view node_name,
|
||||
bool keep_controlling_nodes, bool modified,
|
||||
const NodeDef* expected_node) {
|
||||
GraphDef graph_def = SimpleMutateFaninGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
auto node = graph.GetNode(node_name);
|
||||
absl::flat_hash_set<string> fanin_strings;
|
||||
if (expected_node == nullptr) {
|
||||
EXPECT_EQ(node, nullptr);
|
||||
} else {
|
||||
EXPECT_NE(node, nullptr);
|
||||
fanin_strings.insert(node->input().begin(), node->input().end());
|
||||
}
|
||||
|
||||
EXPECT_EQ(modified, graph.RemoveAllFanins(node_name, keep_controlling_nodes));
|
||||
if (expected_node != nullptr) {
|
||||
CompareNodeInputs(graph, expected_node, node);
|
||||
if (modified) {
|
||||
TensorId tensor_id;
|
||||
auto retained_inputs = absl::flat_hash_set<string>(node->input().begin(),
|
||||
node->input().end());
|
||||
for (const string& fanin : fanin_strings) {
|
||||
if (!retained_inputs.contains(fanin)) {
|
||||
tensor_id = ParseTensorName(fanin);
|
||||
CheckFanout(graph, tensor_id, node_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, RemoveAllFanins) {
|
||||
NodeDef expected_node;
|
||||
// Remove all fanins from node with no control dependencies.
|
||||
expected_node = NDef("", "", {});
|
||||
TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/false,
|
||||
/*modified=*/true, &expected_node);
|
||||
// Remove all fanins from node with control dependencies.
|
||||
TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/false,
|
||||
/*modified=*/true, &expected_node);
|
||||
|
||||
// Remove all fanins from node with no control dependencies and preserve
|
||||
// control dependencies.
|
||||
TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/true,
|
||||
/*modified=*/true, &expected_node);
|
||||
// Remove all fanins from node with control dependencies and preserve control
|
||||
// dependencies.
|
||||
expected_node = NDef("", "", {"^c", "^d"});
|
||||
TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/true,
|
||||
/*modified=*/true, &expected_node);
|
||||
|
||||
// Remove all fanins from node with no fanins.
|
||||
expected_node = NDef("", "", {});
|
||||
TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/false,
|
||||
/*modified=*/false, &expected_node);
|
||||
TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/true,
|
||||
/*modified=*/false, &expected_node);
|
||||
|
||||
// Remove all fanins from node with only control dependencies.
|
||||
TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/false,
|
||||
/*modified=*/true, &expected_node);
|
||||
expected_node = NDef("", "", {"^a", "^b"});
|
||||
TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/true,
|
||||
/*modified=*/false, &expected_node);
|
||||
|
||||
// Remove all fanins from node where node is missing.
|
||||
TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/false,
|
||||
/*modified=*/false, /*expected_node=*/nullptr);
|
||||
TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/true,
|
||||
/*modified=*/false, /*expected_node=*/nullptr);
|
||||
}
|
||||
|
||||
void TestUpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
|
||||
const TensorId& to_fanin, bool modified,
|
||||
const NodeDef* expected_node) {
|
||||
GraphDef graph_def = SimpleMutateFaninGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
auto node = graph.GetNode(node_name);
|
||||
if (expected_node == nullptr) {
|
||||
EXPECT_EQ(node, nullptr);
|
||||
} else {
|
||||
EXPECT_NE(node, nullptr);
|
||||
}
|
||||
|
||||
EXPECT_EQ(modified, graph.UpdateFanin(node_name, from_fanin, to_fanin));
|
||||
if (expected_node != nullptr) {
|
||||
CompareNodeInputs(graph, expected_node, node);
|
||||
if (modified) {
|
||||
CheckFanout(graph, from_fanin, node_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, UpdateFanin) {
|
||||
NodeDef expected_node;
|
||||
// Update fanin from non control to non control.
|
||||
expected_node = NDef("", "", {"a", "b:3", "b:3", "^c", "^d"});
|
||||
TestUpdateFanin("foo_4", {"b", 2}, {"b", 3}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Update fanin from non control to control.
|
||||
expected_node = NDef("", "", {"a", "^c", "^d", "^b"});
|
||||
TestUpdateFanin("foo_4", {"b", 2}, {"b", Graph::kControlSlot},
|
||||
/*modified=*/true, &expected_node);
|
||||
// Update fanin from control to non control.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "d:1", "^c"});
|
||||
TestUpdateFanin("foo_4", {"d", Graph::kControlSlot}, {"d", 1},
|
||||
/*modified=*/true, &expected_node);
|
||||
// Update fanin from control to control.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "^d", "^b"});
|
||||
TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
|
||||
{"b", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
// Update fanin from control to existing control.
|
||||
expected_node = NDef("", "", {"a", "b:2", "b:2", "^d"});
|
||||
TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
|
||||
{"d", Graph::kControlSlot}, /*modified=*/true,
|
||||
&expected_node);
|
||||
|
||||
// Update fanin of node where from and to fanins are the same.
|
||||
expected_node = NDef("", "", {"a"});
|
||||
TestUpdateFanin("foo_1", {"a", -1}, {"a", -1}, /*modified=*/false,
|
||||
&expected_node);
|
||||
TestUpdateFanin("foo_1", {"a", 0}, {"a", 0}, /*modified=*/false,
|
||||
&expected_node);
|
||||
TestUpdateFanin("foo_1", {"a", 1}, {"a", 1}, /*modified=*/false,
|
||||
&expected_node);
|
||||
// Update fanin of node where node is missing.
|
||||
TestUpdateFanin("foo_missing", {"a", 0}, {"a", 1}, /*modified=*/false,
|
||||
/*expected_node=*/nullptr);
|
||||
// Update fanin of node where from fanin is missing.
|
||||
TestUpdateFanin("foo_1", {"from_bar_missing", 0}, {"a", 1},
|
||||
/*modified=*/false, &expected_node);
|
||||
// Update fanin of node where to fanin is missing.
|
||||
TestUpdateFanin("foo_1", {"a", 0}, {"to_bar_missing", 1}, /*modified=*/false,
|
||||
&expected_node);
|
||||
// Update fanin of node where from/to fanins and node are missing.
|
||||
TestUpdateFanin("foo_missing", {"from_bar_missing", 0}, {"to_bar_missing", 1},
|
||||
/*modified=*/false, /*expected_node=*/nullptr);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, DeleteNodes) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
|
Loading…
Reference in New Issue
Block a user