[Grappler] Add node fanin mutations in MutableGraphView.
PiperOrigin-RevId: 225474536
This commit is contained in:
		
							parent
							
								
									3605ae973e
								
							
						
					
					
						commit
						0b0dea8cf1
					
				@ -176,12 +176,14 @@ cc_library(
 | 
			
		||||
    visibility = ["//visibility:public"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":graph_view",
 | 
			
		||||
        ":grappler_item",
 | 
			
		||||
        ":op_types",
 | 
			
		||||
        ":utils",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:protos_all_cc",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
        "@com_google_absl//absl/types:span",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -191,7 +193,9 @@ tf_cc_test(
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":grappler_item",
 | 
			
		||||
        ":mutable_graph_view",
 | 
			
		||||
        ":utils",
 | 
			
		||||
        "//tensorflow/cc:cc_ops",
 | 
			
		||||
        "//tensorflow/core:graph",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "//tensorflow/core:test_main",
 | 
			
		||||
        "//tensorflow/core:testlib",
 | 
			
		||||
 | 
			
		||||
@ -14,14 +14,32 @@ limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <utility>
 | 
			
		||||
 | 
			
		||||
#include "absl/strings/str_cat.h"
 | 
			
		||||
#include "absl/strings/substitute.h"
 | 
			
		||||
#include "tensorflow/core/framework/graph.pb.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def.pb.h"
 | 
			
		||||
#include "tensorflow/core/graph/graph.h"
 | 
			
		||||
#include "tensorflow/core/graph/tensor_id.h"
 | 
			
		||||
#include "tensorflow/core/grappler/op_types.h"
 | 
			
		||||
#include "tensorflow/core/grappler/utils.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/stringpiece.h"
 | 
			
		||||
#include "tensorflow/core/platform/types.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace grappler {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
bool IsTensorIdPortValid(const TensorId& tensor_id) {
 | 
			
		||||
  return tensor_id.index() >= Graph::kControlSlot;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
const absl::flat_hash_set<MutableGraphView::InputPort>&
 | 
			
		||||
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
 | 
			
		||||
  return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
 | 
			
		||||
@ -160,17 +178,201 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
 | 
			
		||||
  NodeDef* fanin_node = GetNode(fanin.node());
 | 
			
		||||
  if (fanin_node == nullptr) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  int num_non_controlling_fanins =
 | 
			
		||||
      NumFanins(*node, /*include_controlling_nodes=*/false);
 | 
			
		||||
  InputPort input;
 | 
			
		||||
  input.node = node;
 | 
			
		||||
  input.port_id = fanin.index() == Graph::kControlSlot
 | 
			
		||||
                      ? Graph::kControlSlot
 | 
			
		||||
                      : num_non_controlling_fanins;
 | 
			
		||||
 | 
			
		||||
  OutputPort fanin_port(fanin_node, fanin.index());
 | 
			
		||||
 | 
			
		||||
  if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  node->add_input(TensorIdToString(fanin));
 | 
			
		||||
  if (fanin.index() > Graph::kControlSlot) {
 | 
			
		||||
    int node_input_size = node->input_size() - 1;
 | 
			
		||||
    // If there are control dependencies in node, move newly inserted fanin to
 | 
			
		||||
    // be before such control dependencies.
 | 
			
		||||
    if (num_non_controlling_fanins < node_input_size) {
 | 
			
		||||
      node->mutable_input()->SwapElements(node_input_size,
 | 
			
		||||
                                          num_non_controlling_fanins);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::AddFanin(absl::string_view node_name,
 | 
			
		||||
                                const TensorId& fanin) {
 | 
			
		||||
  if (!IsTensorIdPortValid(fanin)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  NodeDef* node = GetNode(node_name);
 | 
			
		||||
  if (node == nullptr) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  return AddFanin(node, fanin);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::RemoveFanins(NodeDef* node,
 | 
			
		||||
                                    absl::Span<const TensorId> fanins) {
 | 
			
		||||
  bool modified = false;
 | 
			
		||||
  auto mutable_inputs = node->mutable_input();
 | 
			
		||||
  int curr_pos = 0;
 | 
			
		||||
  int num_inputs = node->input_size();
 | 
			
		||||
  for (int i = 0; i < num_inputs; ++i) {
 | 
			
		||||
    TensorId tensor_id = ParseTensorName(node->input(i));
 | 
			
		||||
    bool remove_fanin =
 | 
			
		||||
        std::find(fanins.begin(), fanins.end(), tensor_id) != fanins.end();
 | 
			
		||||
    bool update_fanin = !remove_fanin && modified;
 | 
			
		||||
    if (remove_fanin || update_fanin) {
 | 
			
		||||
      OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
 | 
			
		||||
 | 
			
		||||
      InputPort input;
 | 
			
		||||
      input.node = node;
 | 
			
		||||
      input.port_id =
 | 
			
		||||
          tensor_id.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
 | 
			
		||||
 | 
			
		||||
      if (remove_fanin) {
 | 
			
		||||
        fanouts()[fanin].erase(input);
 | 
			
		||||
      } else {
 | 
			
		||||
        // Shift inputs to be retained.
 | 
			
		||||
        if (tensor_id.index() > Graph::kControlSlot) {
 | 
			
		||||
          fanouts()[fanin].erase(input);
 | 
			
		||||
          fanouts()[fanin].insert(InputPort(node, i));
 | 
			
		||||
        }
 | 
			
		||||
        mutable_inputs->SwapElements(i, curr_pos++);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      modified = true;
 | 
			
		||||
    } else {
 | 
			
		||||
      // Skip inputs to be retained until first modification.
 | 
			
		||||
      curr_pos++;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  if (modified) {
 | 
			
		||||
    mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
 | 
			
		||||
  }
 | 
			
		||||
  return modified;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::RemoveFanin(absl::string_view node_name,
 | 
			
		||||
                                   const TensorId& fanin) {
 | 
			
		||||
  if (!IsTensorIdPortValid(fanin)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  NodeDef* node = GetNode(node_name);
 | 
			
		||||
  if (node == nullptr) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  return RemoveFanins(node, {fanin});
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::RemoveAllFanins(absl::string_view node_name,
 | 
			
		||||
                                       bool keep_controlling_fanins) {
 | 
			
		||||
  NodeDef* node = GetNode(node_name);
 | 
			
		||||
  if (node == nullptr || node->input().empty()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  RemoveFaninsInternal(node, keep_controlling_fanins);
 | 
			
		||||
  if (keep_controlling_fanins) {
 | 
			
		||||
    int num_non_controlling_fanins =
 | 
			
		||||
        NumFanins(*node, /*include_controlling_nodes=*/false);
 | 
			
		||||
    if (num_non_controlling_fanins == 0) {
 | 
			
		||||
      return false;
 | 
			
		||||
    } else if (num_non_controlling_fanins < node->input_size()) {
 | 
			
		||||
      node->mutable_input()->DeleteSubrange(0, num_non_controlling_fanins);
 | 
			
		||||
    } else {
 | 
			
		||||
      node->clear_input();
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    node->clear_input();
 | 
			
		||||
  }
 | 
			
		||||
  return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool MutableGraphView::UpdateFanin(absl::string_view node_name,
 | 
			
		||||
                                   const TensorId& from_fanin,
 | 
			
		||||
                                   const TensorId& to_fanin) {
 | 
			
		||||
  if (from_fanin == to_fanin || !IsTensorIdPortValid(from_fanin) ||
 | 
			
		||||
      !IsTensorIdPortValid(to_fanin)) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  NodeDef* node = GetNode(node_name);
 | 
			
		||||
  if (node == nullptr) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  bool is_from_fanin_control = from_fanin.index() == Graph::kControlSlot;
 | 
			
		||||
  bool is_to_fanin_control = to_fanin.index() == Graph::kControlSlot;
 | 
			
		||||
  // When replacing a non control dependency fanin with a control dependency, or
 | 
			
		||||
  // vice versa, remove and add, so ports can be updated properly in fanout(s).
 | 
			
		||||
  if (is_from_fanin_control || is_to_fanin_control) {
 | 
			
		||||
    bool modified = RemoveFanins(node, {from_fanin});
 | 
			
		||||
    if (!HasFanin(*node, to_fanin)) {
 | 
			
		||||
      modified |= AddFanin(node, to_fanin);
 | 
			
		||||
    }
 | 
			
		||||
    return modified;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // In place mutation, requires no shifting of ports.
 | 
			
		||||
  NodeDef* from_fanin_node = GetNode(from_fanin.node());
 | 
			
		||||
  NodeDef* to_fanin_node = GetNode(to_fanin.node());
 | 
			
		||||
  if (from_fanin_node == nullptr || to_fanin_node == nullptr) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  string to_fanin_string = TensorIdToString(to_fanin);
 | 
			
		||||
  int num_inputs = node->input_size();
 | 
			
		||||
  bool modified = false;
 | 
			
		||||
  for (int i = 0; i < num_inputs; ++i) {
 | 
			
		||||
    if (ParseTensorName(node->input(i)) == from_fanin) {
 | 
			
		||||
      OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
 | 
			
		||||
      InputPort old_input;
 | 
			
		||||
      old_input.node = node;
 | 
			
		||||
      old_input.port_id =
 | 
			
		||||
          from_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
 | 
			
		||||
      fanouts()[from_fanin_port].erase(old_input);
 | 
			
		||||
 | 
			
		||||
      OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
 | 
			
		||||
      InputPort new_input;
 | 
			
		||||
      new_input.node = node;
 | 
			
		||||
      new_input.port_id =
 | 
			
		||||
          to_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
 | 
			
		||||
      fanouts()[to_fanin_port].insert(new_input);
 | 
			
		||||
 | 
			
		||||
      node->set_input(i, to_fanin_string);
 | 
			
		||||
      modified = true;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  return modified;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
 | 
			
		||||
  for (const string& node_name_to_delete : nodes_to_delete)
 | 
			
		||||
    RemoveFanouts(nodes().at(node_name_to_delete));
 | 
			
		||||
    RemoveFaninsInternal(nodes().at(node_name_to_delete),
 | 
			
		||||
                         /*keep_controlling_fanins=*/false);
 | 
			
		||||
  for (const string& node_name_to_delete : nodes_to_delete)
 | 
			
		||||
    nodes().erase(node_name_to_delete);
 | 
			
		||||
  EraseNodesFromGraph(nodes_to_delete, graph());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) {
 | 
			
		||||
void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
 | 
			
		||||
                                            bool keep_controlling_fanins) {
 | 
			
		||||
  for (int i = 0; i < deleted_node->input_size(); ++i) {
 | 
			
		||||
    TensorId tensor_id = ParseTensorName(deleted_node->input(i));
 | 
			
		||||
    if (keep_controlling_fanins && tensor_id.index() < 0) {
 | 
			
		||||
      break;
 | 
			
		||||
    }
 | 
			
		||||
    OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
 | 
			
		||||
 | 
			
		||||
    InputPort input;
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,17 @@ limitations under the License.
 | 
			
		||||
#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
 | 
			
		||||
#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
 | 
			
		||||
 | 
			
		||||
#include <set>
 | 
			
		||||
#include <string>
 | 
			
		||||
 | 
			
		||||
#include "absl/container/flat_hash_set.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "absl/types/span.h"
 | 
			
		||||
#include "tensorflow/core/framework/graph.pb.h"
 | 
			
		||||
#include "tensorflow/core/framework/node_def.pb.h"
 | 
			
		||||
#include "tensorflow/core/graph/tensor_id.h"
 | 
			
		||||
#include "tensorflow/core/grappler/graph_view.h"
 | 
			
		||||
#include "tensorflow/core/platform/types.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace grappler {
 | 
			
		||||
@ -60,6 +70,38 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
 | 
			
		||||
  //   2. foo2(new_bar:1, other:1)
 | 
			
		||||
  void UpdateFanouts(absl::string_view from_node, absl::string_view to_node);
 | 
			
		||||
 | 
			
		||||
  // Add fanin to node `node_name`. If the node or fanin do not exist in the
 | 
			
		||||
  // graph, nothing will be modified in the graph. If fanin is a control
 | 
			
		||||
  // dependency, existing control dependencies will be checked first before
 | 
			
		||||
  // adding. Otherwise fanin will be added after existing non control dependency
 | 
			
		||||
  // inputs.
 | 
			
		||||
  //
 | 
			
		||||
  // This will return true iff the node is modified. If a control dependency
 | 
			
		||||
  // already exists, the node will not be modified.
 | 
			
		||||
  bool AddFanin(absl::string_view node_name, const TensorId& fanin);
 | 
			
		||||
 | 
			
		||||
  // Remove fanin from node `node_name`. If the node or fanin do not exist in
 | 
			
		||||
  // the graph, nothing will be modified in the graph. If there are multiple
 | 
			
		||||
  // inputs that match the fanin, all of them will be removed.
 | 
			
		||||
  //
 | 
			
		||||
  // This will return true iff the node is modified. If no inputs match the
 | 
			
		||||
  // fanin, the node will not be modified.
 | 
			
		||||
  bool RemoveFanin(absl::string_view node_name, const TensorId& fanin);
 | 
			
		||||
 | 
			
		||||
  // Remove all fanins from node `node_name`. Control dependencies will be
 | 
			
		||||
  // retained if keep_controlling_fanins is true.
 | 
			
		||||
  //
 | 
			
		||||
  // This will return true iff the node is modified.
 | 
			
		||||
  bool RemoveAllFanins(absl::string_view node_name,
 | 
			
		||||
                       bool keep_controlling_fanins);
 | 
			
		||||
 | 
			
		||||
  // Replace all fanins `from_fanin` with `to_fanin` in node `node_name`. If
 | 
			
		||||
  // the fanins or node do not exist, nothing will be modified in the graph.
 | 
			
		||||
  //
 | 
			
		||||
  // This will return true iff the node is modified.
 | 
			
		||||
  bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
 | 
			
		||||
                   const TensorId& to_fanin);
 | 
			
		||||
 | 
			
		||||
  // Deletes nodes from the graph.
 | 
			
		||||
  void DeleteNodes(const std::set<string>& nodes_to_delete);
 | 
			
		||||
 | 
			
		||||
@ -79,9 +121,22 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
 | 
			
		||||
  // behavior is undefined.
 | 
			
		||||
  void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
 | 
			
		||||
 | 
			
		||||
  // Remove fanouts of the deleted node from internal state (including control
 | 
			
		||||
  // dependencies).
 | 
			
		||||
  void RemoveFanouts(NodeDef* deleted_node);
 | 
			
		||||
  // Remove fanins of the deleted node from internal state. Control dependencies
 | 
			
		||||
  // are retained iff keep_controlling_fanins is true.
 | 
			
		||||
  void RemoveFaninsInternal(NodeDef* deleted_node,
 | 
			
		||||
                            bool keep_controlling_fanins);
 | 
			
		||||
 | 
			
		||||
  // Add fanin to node. If the node or fanin do not exist in the graph, nothing
 | 
			
		||||
  // will be modified in the graph. If fanin is a control dependency, existing
 | 
			
		||||
  // control dependencies will be checked first before adding. Otherwise fanin
 | 
			
		||||
  // will be added after existing non control dependency inputs.
 | 
			
		||||
  //
 | 
			
		||||
  // This will return true iff the node is modified. If a control dependency
 | 
			
		||||
  // already exists, the node will not be modified.
 | 
			
		||||
  bool AddFanin(NodeDef* node, const TensorId& fanin);
 | 
			
		||||
 | 
			
		||||
  // Remove any fanin in node that matches to a fanin in fanins.
 | 
			
		||||
  bool RemoveFanins(NodeDef* node, absl::Span<const TensorId> fanins);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // end namespace grappler
 | 
			
		||||
 | 
			
		||||
@ -16,8 +16,10 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
 | 
			
		||||
#include "tensorflow/cc/ops/standard_ops.h"
 | 
			
		||||
#include "tensorflow/core/framework/function_testlib.h"
 | 
			
		||||
#include "tensorflow/core/graph/tensor_id.h"
 | 
			
		||||
#include "tensorflow/core/grappler/grappler_item.h"
 | 
			
		||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 | 
			
		||||
#include "tensorflow/core/grappler/utils.h"
 | 
			
		||||
#include "tensorflow/core/platform/test.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
@ -110,6 +112,356 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
 | 
			
		||||
  EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
GraphDef SimpleMutateFaninGraph() {
 | 
			
		||||
  // Actual node.op() is not important in this test.
 | 
			
		||||
  GraphDef graph_def = test::function::GDef(
 | 
			
		||||
      {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
 | 
			
		||||
       NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
 | 
			
		||||
       NDef("foo_1", "NotImportant", {"a"}),
 | 
			
		||||
       NDef("foo_2", "NotImportant", {"b", "^a", "^c"}),
 | 
			
		||||
       NDef("foo_3", "NotImportant", {"b", "a:1", "a:1"}),
 | 
			
		||||
       NDef("foo_4", "NotImportant", {"a", "b:2", "b:2", "^c", "^d"}),
 | 
			
		||||
       NDef("foo_5", "NotImportant", {}),
 | 
			
		||||
       NDef("foo_6", "NotImportant", {"^a", "^b"})},
 | 
			
		||||
      /*funcs=*/{});
 | 
			
		||||
  return graph_def;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CompareNodeInputs(const MutableGraphView& graph, const NodeDef* expected,
 | 
			
		||||
                       NodeDef* actual) {
 | 
			
		||||
  ASSERT_EQ(actual->input_size(), expected->input_size());
 | 
			
		||||
  int port;
 | 
			
		||||
  for (int i = 0; i < actual->input_size(); ++i) {
 | 
			
		||||
    EXPECT_EQ(actual->input(i), expected->input(i));
 | 
			
		||||
    TensorId tensor_id = ParseTensorName(expected->input(i));
 | 
			
		||||
    if (tensor_id.index() == Graph::kControlSlot) {
 | 
			
		||||
      port = Graph::kControlSlot;
 | 
			
		||||
    } else {
 | 
			
		||||
      port = i;
 | 
			
		||||
    }
 | 
			
		||||
    MutableGraphView::InputPort input_port(actual, port);
 | 
			
		||||
    MutableGraphView::OutputPort output_port =
 | 
			
		||||
        graph.GetOutputPort(tensor_id.node(), tensor_id.index());
 | 
			
		||||
    EXPECT_EQ(graph.GetFanin(input_port).contains(output_port), true);
 | 
			
		||||
    EXPECT_EQ(graph.GetFanout(output_port).contains(input_port), true);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TestAddFanin(absl::string_view node_name, const TensorId& fanin_to_add,
 | 
			
		||||
                  bool modified, const NodeDef* expected_node) {
 | 
			
		||||
  GraphDef graph_def = SimpleMutateFaninGraph();
 | 
			
		||||
 | 
			
		||||
  MutableGraphView graph(&graph_def);
 | 
			
		||||
 | 
			
		||||
  auto node = graph.GetNode(node_name);
 | 
			
		||||
  if (expected_node == nullptr) {
 | 
			
		||||
    EXPECT_EQ(node, nullptr);
 | 
			
		||||
  } else {
 | 
			
		||||
    EXPECT_NE(node, nullptr);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(modified, graph.AddFanin(node_name, fanin_to_add));
 | 
			
		||||
  if (expected_node != nullptr) {
 | 
			
		||||
    CompareNodeInputs(graph, expected_node, node);
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MutableGraphViewTest, AddFanin) {
 | 
			
		||||
  NodeDef expected_node;
 | 
			
		||||
  // Add input to node with 1 input 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:1"});
 | 
			
		||||
  TestAddFanin("foo_1", {"b", 1}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Add input to node with multiple inputs and 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "a:1", "a:1", "b:2"});
 | 
			
		||||
  TestAddFanin("foo_3", {"b", 2}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Add input to node with 1 input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "a", "^c", "^a"});
 | 
			
		||||
  TestAddFanin("foo_2", {"a", 0}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Add input to node with multiple inputs and controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "a:1", "^d", "^c"});
 | 
			
		||||
  TestAddFanin("foo_4", {"a", 1}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Add input to node with 0 inputs 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a:1"});
 | 
			
		||||
  TestAddFanin("foo_5", {"a", 1}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Add input to node with 0 inputs multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"c:1", "^b", "^a"});
 | 
			
		||||
  TestAddFanin("foo_6", {"c", 1}, /*modified=*/true, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Add control to node with 1 input 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "^b"});
 | 
			
		||||
  TestAddFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with multiple inputs and 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "a:1", "a:1", "^c"});
 | 
			
		||||
  TestAddFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with 1 input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "^a", "^c", "^d"});
 | 
			
		||||
  TestAddFanin("foo_2", {"d", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with multiple input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "^c", "^d", "^a"});
 | 
			
		||||
  TestAddFanin("foo_4", {"a", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with 0 inputs 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"^a"});
 | 
			
		||||
  TestAddFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with 0 inputs multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"^a", "^b", "^c"});
 | 
			
		||||
  TestAddFanin("foo_6", {"c", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
  // Add control to node with control that already exists.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "^a", "^c"});
 | 
			
		||||
  TestAddFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/false,
 | 
			
		||||
               &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Add fanin to node where node is missing.
 | 
			
		||||
  TestAddFanin("foo_missing", {"a", 0}, /*modified=*/false, nullptr);
 | 
			
		||||
  // Add fanin to node where fanin is missing.
 | 
			
		||||
  expected_node = NDef("", "", {"a"});
 | 
			
		||||
  TestAddFanin("foo_1", {"bar_missing", 0}, /*modified=*/false, &expected_node);
 | 
			
		||||
  // Add fanin to node where node and fanin are missing.
 | 
			
		||||
  TestAddFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
 | 
			
		||||
               /*expected_node=*/nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void CheckFanout(const MutableGraphView& graph, const TensorId& fanin,
 | 
			
		||||
                 absl::string_view node_name) {
 | 
			
		||||
  MutableGraphView::OutputPort output_port =
 | 
			
		||||
      graph.GetOutputPort(fanin.node(), fanin.index());
 | 
			
		||||
  auto fanouts = graph.GetFanout(output_port);
 | 
			
		||||
  for (auto fanout : fanouts) {
 | 
			
		||||
    EXPECT_NE(fanout.node->name(), fanin.node());
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TestRemoveFanin(absl::string_view node_name,
 | 
			
		||||
                     const TensorId& fanin_to_remove, bool modified,
 | 
			
		||||
                     const NodeDef* expected_node) {
 | 
			
		||||
  GraphDef graph_def = SimpleMutateFaninGraph();
 | 
			
		||||
 | 
			
		||||
  MutableGraphView graph(&graph_def);
 | 
			
		||||
 | 
			
		||||
  auto node = graph.GetNode(node_name);
 | 
			
		||||
  if (expected_node == nullptr) {
 | 
			
		||||
    EXPECT_EQ(nullptr, node);
 | 
			
		||||
  } else {
 | 
			
		||||
    EXPECT_NE(nullptr, node);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(modified, graph.RemoveFanin(node_name, fanin_to_remove));
 | 
			
		||||
  if (expected_node != nullptr) {
 | 
			
		||||
    CompareNodeInputs(graph, expected_node, node);
 | 
			
		||||
    if (modified) {
 | 
			
		||||
      CheckFanout(graph, fanin_to_remove, node_name);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MutableGraphViewTest, RemoveFanin) {
 | 
			
		||||
  NodeDef expected_node;
 | 
			
		||||
  // Remove input from node with 1 input 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {});
 | 
			
		||||
  TestRemoveFanin("foo_1", {"a", 0}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Remove input from node with multiple inputs and 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b"});
 | 
			
		||||
  TestRemoveFanin("foo_3", {"a", 1}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Remove input from node with 1 input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"^a", "^c"});
 | 
			
		||||
  TestRemoveFanin("foo_2", {"b", 0}, /*modified=*/true, &expected_node);
 | 
			
		||||
  // Remove input from node with multiple inputs and controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "^c", "^d"});
 | 
			
		||||
  TestRemoveFanin("foo_4", {"b", 2}, /*modified=*/true, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove control from node with 1 input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "^c"});
 | 
			
		||||
  TestRemoveFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Remove control from node with multiple input multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "^c"});
 | 
			
		||||
  TestRemoveFanin("foo_4", {"d", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Remove control from node with 0 inputs multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"^b"});
 | 
			
		||||
  TestRemoveFanin("foo_6", {"a", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove input from node with 0 inputs 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {});
 | 
			
		||||
  TestRemoveFanin("foo_5", {"a", 1}, /*modified=*/false, &expected_node);
 | 
			
		||||
  // Remove input from node with 0 inputs multiple controls.
 | 
			
		||||
  expected_node = NDef("", "", {"^a", "^b"});
 | 
			
		||||
  TestRemoveFanin("foo_6", {"a", 1}, /*modified=*/false, &expected_node);
 | 
			
		||||
  // Remove control from node with 1 input 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"a"});
 | 
			
		||||
  TestRemoveFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Remove control from node with multiple inputs and 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {"b", "a:1", "a:1"});
 | 
			
		||||
  TestRemoveFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Remove control from node with 0 inputs 0 controls.
 | 
			
		||||
  expected_node = NDef("", "", {});
 | 
			
		||||
  TestRemoveFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove fanin from node where node is missing.
 | 
			
		||||
  TestRemoveFanin("foo_missing", {"a", 0}, /*modified=*/false,
 | 
			
		||||
                  /*expected_node=*/nullptr);
 | 
			
		||||
  // Remove fanin from node where fanin is missing.
 | 
			
		||||
  expected_node = NDef("", "", {"a"});
 | 
			
		||||
  TestRemoveFanin("foo_1", {"bar_missing", 0}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Remove fanin from node where node and fanin are missing.
 | 
			
		||||
  TestRemoveFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
 | 
			
		||||
                  /*expected_node=*/nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TestRemoveAllFanins(absl::string_view node_name,
 | 
			
		||||
                         bool keep_controlling_nodes, bool modified,
 | 
			
		||||
                         const NodeDef* expected_node) {
 | 
			
		||||
  GraphDef graph_def = SimpleMutateFaninGraph();
 | 
			
		||||
 | 
			
		||||
  MutableGraphView graph(&graph_def);
 | 
			
		||||
 | 
			
		||||
  auto node = graph.GetNode(node_name);
 | 
			
		||||
  absl::flat_hash_set<string> fanin_strings;
 | 
			
		||||
  if (expected_node == nullptr) {
 | 
			
		||||
    EXPECT_EQ(node, nullptr);
 | 
			
		||||
  } else {
 | 
			
		||||
    EXPECT_NE(node, nullptr);
 | 
			
		||||
    fanin_strings.insert(node->input().begin(), node->input().end());
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(modified, graph.RemoveAllFanins(node_name, keep_controlling_nodes));
 | 
			
		||||
  if (expected_node != nullptr) {
 | 
			
		||||
    CompareNodeInputs(graph, expected_node, node);
 | 
			
		||||
    if (modified) {
 | 
			
		||||
      TensorId tensor_id;
 | 
			
		||||
      auto retained_inputs = absl::flat_hash_set<string>(node->input().begin(),
 | 
			
		||||
                                                         node->input().end());
 | 
			
		||||
      for (const string& fanin : fanin_strings) {
 | 
			
		||||
        if (!retained_inputs.contains(fanin)) {
 | 
			
		||||
          tensor_id = ParseTensorName(fanin);
 | 
			
		||||
          CheckFanout(graph, tensor_id, node_name);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MutableGraphViewTest, RemoveAllFanins) {
 | 
			
		||||
  NodeDef expected_node;
 | 
			
		||||
  // Remove all fanins from node with no control dependencies.
 | 
			
		||||
  expected_node = NDef("", "", {});
 | 
			
		||||
  TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/false,
 | 
			
		||||
                      /*modified=*/true, &expected_node);
 | 
			
		||||
  // Remove all fanins from node with control dependencies.
 | 
			
		||||
  TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/false,
 | 
			
		||||
                      /*modified=*/true, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove all fanins from node with no control dependencies and preserve
 | 
			
		||||
  // control dependencies.
 | 
			
		||||
  TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/true,
 | 
			
		||||
                      /*modified=*/true, &expected_node);
 | 
			
		||||
  // Remove all fanins from node with control dependencies and preserve control
 | 
			
		||||
  // dependencies.
 | 
			
		||||
  expected_node = NDef("", "", {"^c", "^d"});
 | 
			
		||||
  TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/true,
 | 
			
		||||
                      /*modified=*/true, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove all fanins from node with no fanins.
 | 
			
		||||
  expected_node = NDef("", "", {});
 | 
			
		||||
  TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/false,
 | 
			
		||||
                      /*modified=*/false, &expected_node);
 | 
			
		||||
  TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/true,
 | 
			
		||||
                      /*modified=*/false, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove all fanins from node with only control dependencies.
 | 
			
		||||
  TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/false,
 | 
			
		||||
                      /*modified=*/true, &expected_node);
 | 
			
		||||
  expected_node = NDef("", "", {"^a", "^b"});
 | 
			
		||||
  TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/true,
 | 
			
		||||
                      /*modified=*/false, &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Remove all fanins from node where node is missing.
 | 
			
		||||
  TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/false,
 | 
			
		||||
                      /*modified=*/false, /*expected_node=*/nullptr);
 | 
			
		||||
  TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/true,
 | 
			
		||||
                      /*modified=*/false, /*expected_node=*/nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void TestUpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
 | 
			
		||||
                     const TensorId& to_fanin, bool modified,
 | 
			
		||||
                     const NodeDef* expected_node) {
 | 
			
		||||
  GraphDef graph_def = SimpleMutateFaninGraph();
 | 
			
		||||
 | 
			
		||||
  MutableGraphView graph(&graph_def);
 | 
			
		||||
 | 
			
		||||
  auto node = graph.GetNode(node_name);
 | 
			
		||||
  if (expected_node == nullptr) {
 | 
			
		||||
    EXPECT_EQ(node, nullptr);
 | 
			
		||||
  } else {
 | 
			
		||||
    EXPECT_NE(node, nullptr);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  EXPECT_EQ(modified, graph.UpdateFanin(node_name, from_fanin, to_fanin));
 | 
			
		||||
  if (expected_node != nullptr) {
 | 
			
		||||
    CompareNodeInputs(graph, expected_node, node);
 | 
			
		||||
    if (modified) {
 | 
			
		||||
      CheckFanout(graph, from_fanin, node_name);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MutableGraphViewTest, UpdateFanin) {
 | 
			
		||||
  NodeDef expected_node;
 | 
			
		||||
  // Update fanin from non control to non control.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:3", "b:3", "^c", "^d"});
 | 
			
		||||
  TestUpdateFanin("foo_4", {"b", 2}, {"b", 3}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Update fanin from non control to control.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "^c", "^d", "^b"});
 | 
			
		||||
  TestUpdateFanin("foo_4", {"b", 2}, {"b", Graph::kControlSlot},
 | 
			
		||||
                  /*modified=*/true, &expected_node);
 | 
			
		||||
  // Update fanin from control to non control.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "d:1", "^c"});
 | 
			
		||||
  TestUpdateFanin("foo_4", {"d", Graph::kControlSlot}, {"d", 1},
 | 
			
		||||
                  /*modified=*/true, &expected_node);
 | 
			
		||||
  // Update fanin from control to control.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "^d", "^b"});
 | 
			
		||||
  TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
 | 
			
		||||
                  {"b", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Update fanin from control to existing control.
 | 
			
		||||
  expected_node = NDef("", "", {"a", "b:2", "b:2", "^d"});
 | 
			
		||||
  TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
 | 
			
		||||
                  {"d", Graph::kControlSlot}, /*modified=*/true,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
 | 
			
		||||
  // Update fanin of node where from and to fanins are the same.
 | 
			
		||||
  expected_node = NDef("", "", {"a"});
 | 
			
		||||
  TestUpdateFanin("foo_1", {"a", -1}, {"a", -1}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  TestUpdateFanin("foo_1", {"a", 0}, {"a", 0}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  TestUpdateFanin("foo_1", {"a", 1}, {"a", 1}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Update fanin of node where node is missing.
 | 
			
		||||
  TestUpdateFanin("foo_missing", {"a", 0}, {"a", 1}, /*modified=*/false,
 | 
			
		||||
                  /*expected_node=*/nullptr);
 | 
			
		||||
  // Update fanin of node where from fanin is missing.
 | 
			
		||||
  TestUpdateFanin("foo_1", {"from_bar_missing", 0}, {"a", 1},
 | 
			
		||||
                  /*modified=*/false, &expected_node);
 | 
			
		||||
  // Update fanin of node where to fanin is missing.
 | 
			
		||||
  TestUpdateFanin("foo_1", {"a", 0}, {"to_bar_missing", 1}, /*modified=*/false,
 | 
			
		||||
                  &expected_node);
 | 
			
		||||
  // Update fanin of node where from/to fanins and node are missing.
 | 
			
		||||
  TestUpdateFanin("foo_missing", {"from_bar_missing", 0}, {"to_bar_missing", 1},
 | 
			
		||||
                  /*modified=*/false, /*expected_node=*/nullptr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST(MutableGraphViewTest, DeleteNodes) {
 | 
			
		||||
  // Actual node.op() is not important in this test.
 | 
			
		||||
  GraphDef graph_def = test::function::GDef(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user