From ba6a8875dc3becbeff7e74c6df33b7746af2b691 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Mon, 20 May 2019 08:15:24 -0700 Subject: [PATCH] [Grappler] New GraphView for mutating graphs with validation of node connectivity and better error handling. PiperOrigin-RevId: 249051462 --- tensorflow/core/grappler/utils/BUILD | 68 + tensorflow/core/grappler/utils/graph_view.cc | 1450 ++++++++++ tensorflow/core/grappler/utils/graph_view.h | 490 ++++ .../core/grappler/utils/graph_view_internal.h | 898 ++++++ .../utils/graph_view_internal_test.cc | 1112 +++++++ .../core/grappler/utils/graph_view_test.cc | 2545 +++++++++++++++++ 6 files changed, 6563 insertions(+) create mode 100644 tensorflow/core/grappler/utils/graph_view.cc create mode 100644 tensorflow/core/grappler/utils/graph_view.h create mode 100644 tensorflow/core/grappler/utils/graph_view_internal.h create mode 100644 tensorflow/core/grappler/utils/graph_view_internal_test.cc create mode 100644 tensorflow/core/grappler/utils/graph_view_test.cc diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 2c3ae8b9d96..8b8a3b4ef5d 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -303,3 +303,71 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "graph_view_internal", + hdrs = ["graph_view_internal.h"], + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "graph_view_internal_test", + srcs = ["graph_view_internal_test.cc"], + deps = [ + ":graph_view", + ":graph_view_internal", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "graph_view", + srcs = ["graph_view.cc"], + hdrs = ["graph_view.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_view_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) + +tf_cc_test( + name = "graph_view_test", + srcs = ["graph_view_test.cc"], + deps = [ + ":graph_view", + ":grappler_test", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/core/grappler/utils/graph_view.cc b/tensorflow/core/grappler/utils/graph_view.cc new file mode 100644 index 00000000000..3d823d2daf9 --- /dev/null +++ b/tensorflow/core/grappler/utils/graph_view.cc @@ -0,0 +1,1450 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/graph_view.h" + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { +namespace grappler { +namespace utils { + +FaninView::FaninView(NodeView* node_view, int index) + : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_, + index) {} + +FanoutView::FanoutView(NodeView* node_view, int index) + : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_, + index) {} + +const NodeDef* NodeView::node() const { + return &graph_view_->graph()->node(node_index_); +} + +bool NodeView::HasFanin(const FanoutView& fanin) const { + if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) { + return false; + } + return fanins_set_.contains( + {&graph_view_->graph_->node(fanin.node_index_), fanin.index()}); +} + +bool NodeView::HasFanout(const FaninView& fanout) const { + if (fanout.index() < Graph::kControlSlot || + graph_view_ != fanout.graph_view_) { + return false; + } + NodeView* view = fanout.node_view(); + if (view == nullptr) { + return false; + } else if (fanout.index() == Graph::kControlSlot) { + return view->fanins_set_.contains({this->node(), Graph::kControlSlot}); + } else if (fanout.index() >= view->regular_fanins_.size()) { + return false; + } + return view->regular_fanins_[fanout.index()].node_index_ == node_index_; +} + +inline const FanoutView& NodeView::GetMissingFanin() const { + return graph_view_->missing_fanin_; +} + +inline const std::vector& NodeView::GetMissingFanout() const { + return graph_view_->missing_fanout_; +} + +namespace { +const char kGraphViewError[] = "GraphView::GraphView error: "; +} // namespace + +GraphView::GraphView(const GraphDef* graph, Status* status) + : GraphViewInternal(graph) { + const int num_nodes = graph->node_size(); + node_index_by_name_.reserve(num_nodes); + nodes_.reserve(num_nodes); + for (const NodeDef& node : graph->node()) { + if (!AddUniqueNodeInternal(&node)) { + *status = errors::InvalidArgument( + kGraphViewError, "graph has multiple nodes with the name '", + node.name(), "'."); + Reset(); + return; + } + } + Status s; + for (NodeView& node_view : nodes_) { + s = CheckAndAddFaninsInternal(&node_view); + if (!s.ok()) { + *status = s; + Reset(); + return; + } + } + *status = Status::OK(); +} + +bool GraphView::AddUniqueNodeInternal(const NodeDef* node) { + const int node_index = node_index_by_name_.size(); + auto it = node_index_by_name_.emplace(node->name(), node_index); + if (it.second) { + nodes_.emplace_back(this, node_index); + return true; + } + return false; +} + +Status GraphView::CheckAndAddFaninsInternal(NodeView* node_view) { + bool has_observed_control = false; + const NodeDef* node = node_view->node(); + const string& node_name = node->name(); + const int node_index = node_view->node_index_; + node_view->fanins_set_.reserve(node->input_size()); + for (const string& input : node->input()) { + TensorId fanin_id = ParseTensorName(input); + if (fanin_id.node() == node_name) { + return errors::InvalidArgument(kGraphViewError, "node '", node_name, + "' has self cycle fanin '", input, "'."); + } + bool is_control = IsTensorIdControl(fanin_id); + if (!is_control && has_observed_control) { + return errors::InvalidArgument(kGraphViewError, "node '", node_name, + "' has regular fanin '", input, + "' after controlling fanins."); + } + auto it = node_index_by_name_.find(fanin_id.node()); + if (it == node_index_by_name_.end()) { + return errors::InvalidArgument(kGraphViewError, "node '", node_name, + "' has missing fanin '", input, "'."); + } + const int fanin_node_index = it->second; + NodeView& fanin_node_view = nodes_[fanin_node_index]; + + if (is_control) { + fanin_node_view.controlled_fanouts_.emplace_back(this, node_index, + Graph::kControlSlot); + node_view->controlling_fanins_.emplace_back(this, fanin_node_index, + Graph::kControlSlot); + node_view->fanins_set_.emplace(fanin_node_view.node(), + Graph::kControlSlot); + has_observed_control = true; + } else { + if (fanin_node_view.regular_fanouts_by_port_.size() < + fanin_id.index() + 1) { + fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1); + } + fanin_node_view.regular_fanouts_by_port_[fanin_id.index()].emplace_back( + this, node_index, node_view->regular_fanins_.size()); + ++fanin_node_view.num_regular_fanouts_; + node_view->regular_fanins_.emplace_back(this, fanin_node_index, + fanin_id.index()); + node_view->fanins_set_.emplace(fanin_node_view.node(), fanin_id.index()); + } + } + return Status::OK(); +} + +MutableFaninView::MutableFaninView(MutableNodeView* node_view, int index) + : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_, + index) {} + +MutableFanoutView::MutableFanoutView(MutableNodeView* node_view, int index) + : NodeIndexAndPortIndex(node_view->graph_view_, node_view->node_index_, + index) {} + +NodeDef* MutableNodeView::node() const { + return graph_view_->graph()->mutable_node(node_index_); +} + +bool MutableNodeView::HasFanin(const MutableFanoutView& fanin) const { + if (fanin.index() < Graph::kControlSlot || graph_view_ != fanin.graph_view_) { + return false; + } + return fanins_count_.contains( + {&graph_view_->graph_->node(fanin.node_index_), fanin.index()}); +} + +bool MutableNodeView::HasFanout(const MutableFaninView& fanout) const { + if (fanout.index() < Graph::kControlSlot || + graph_view_ != fanout.graph_view_) { + return false; + } + MutableNodeView* view = fanout.node_view(); + if (view == nullptr) { + return false; + } else if (fanout.index() == Graph::kControlSlot) { + return view->fanins_count_.contains({this->node(), Graph::kControlSlot}); + } else if (fanout.index() >= view->regular_fanins_.size()) { + return false; + } + return view->regular_fanins_[fanout.index()].node_index_ == node_index_; +} + +const MutableFanoutView& MutableNodeView::GetMissingFanin() const { + return graph_view_->missing_fanin_; +} + +const std::vector& MutableNodeView::GetMissingFanout() const { + return graph_view_->missing_fanout_; +} + +namespace { +const char kMutationAddNodeError[] = "Mutation::AddNode error: "; + +bool IsTensorIdRegular(const TensorId& tensor_id) { + return tensor_id.index() >= 0; +} +} // namespace + +Mutation::Mutation(MutableGraphView* graph_view) : graph_view_(graph_view) {} + +MutationNewNode Mutation::AddNode(NodeDef&& node, Status* status) { + bool has_observed_control = false; + const string& node_name = node.name(); + std::vector regular_fanins; + absl::flat_hash_set controlling_fanins; + const int num_fanins = node.input_size(); + for (int i = 0; i < num_fanins; ++i) { + const string& input = node.input(i); + TensorId fanin_id = ParseTensorName(input); + if (fanin_id.node() == node_name) { + *status = + errors::InvalidArgument(kMutationAddNodeError, "node '", node_name, + "' has self cycle fanin '", input, "'."); + return MutationNewNode(this, mutation_counter_, internal::kMissingIndex); + } + bool is_control = IsTensorIdControl(fanin_id); + if (is_control) { + has_observed_control = true; + controlling_fanins.emplace(fanin_id.node()); + } else if (has_observed_control) { + *status = errors::InvalidArgument(kMutationAddNodeError, "node '", + node_name, "' has regular fanin '", + input, "' after controlling fanins."); + return MutationNewNode(this, mutation_counter_, internal::kMissingIndex); + } else { + regular_fanins.emplace_back(fanin_id); + } + } + + node.mutable_input()->Clear(); + new_nodes_.emplace_back(graph_view_, std::move(node)); + MutationNewNodeHolder& mutation_node = new_nodes_.back(); + mutation_node.regular_fanins = std::move(regular_fanins); + mutation_node.num_regular_fanins = mutation_node.regular_fanins.size(); + mutation_node.controlling_fanins = std::move(controlling_fanins); + *status = Status::OK(); + return MutationNewNode(this, mutation_counter_, new_nodes_.size() - 1); +} + +void Mutation::AddMutation( + MutableNodeView* node, + std::function mutate_fn) { + DCHECK(node->graph_view_ == graph_view_); + if (node->update_index_ == internal::kMissingIndex) { + node->update_index_ = updated_nodes_.size(); + updated_nodes_.emplace_back(graph_view_, node->node_index_); + mutate_fn(&updated_nodes_.back()); + } else { + auto& diff = updated_nodes_[node->update_index_]; + if (!diff.removed) { + mutate_fn(&diff); + } + } +} + +void Mutation::RemoveNode(MutableNodeView* node) { + AddMutation(node, [](MutableNodeViewDiff* diff) { + // Clear existing MutableNodeViewDiff as when node is removed no change to + // its internal state matter. + internal::Reset(diff); + internal::SetRemoved(diff, true); + }); +} + +void Mutation::UpdateNodeName(MutableNodeView* node, absl::string_view name) { + AddMutation(node, [name](MutableNodeViewDiff* diff) { + internal::UpdateName(diff, name); + }); +} + +void Mutation::UpdateNodeName(const MutationNewNode& node, + absl::string_view name) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::UpdateName(&new_nodes_[node.index_], name); +} + +void Mutation::UpdateNodeOp(MutableNodeView* node, absl::string_view op) { + AddMutation( + node, [op](MutableNodeViewDiff* diff) { internal::UpdateOp(diff, op); }); +} + +void Mutation::UpdateNodeOp(const MutationNewNode& node, absl::string_view op) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::UpdateOp(&new_nodes_[node.index_], op); +} + +void Mutation::UpdateNodeDevice(MutableNodeView* node, + absl::string_view device) { + AddMutation(node, [device](MutableNodeViewDiff* diff) { + internal::UpdateDevice(diff, device); + }); +} + +void Mutation::UpdateNodeDevice(const MutationNewNode& node, + absl::string_view device) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::UpdateDevice(&new_nodes_[node.index_], device); +} + +void Mutation::AddOrUpdateRegularFanin(MutableNodeView* node, int index, + const TensorId& fanin) { + AddMutation(node, [index, fanin](MutableNodeViewDiff* diff) { + internal::AddOrUpdateRegularFanin(diff, index, fanin); + }); +} + +void Mutation::AddOrUpdateRegularFanin(const MutationNewNode& node, int index, + const TensorId& fanin) { + DCHECK(node.mutation_ == this && + node.mutation_counter_ == mutation_counter_ && index >= 0 && + IsTensorIdRegular(fanin)); + internal::AddOrUpdateRegularFanin(&new_nodes_[node.index_], index, fanin); +} + +void Mutation::RemoveRegularFanin(MutableNodeView* node, int index) { + AddMutation(node, [index](MutableNodeViewDiff* diff) { + internal::RemoveRegularFanin(diff, index); + }); +} + +void Mutation::RemoveRegularFanin(const MutationNewNode& node, int index) { + DCHECK(node.mutation_ == this && + node.mutation_counter_ == mutation_counter_ && index >= 0); + internal::RemoveRegularFanin(&new_nodes_[node.index_], index); +} + +void Mutation::AddControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name) { + AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) { + auto it = node->controlling_fanins_index_.find(fanin_node_name); + const int control_index = it != node->controlling_fanins_index_.end() + ? it->second + : internal::kMissingIndex; + internal::AddControllingFanin(diff, control_index, fanin_node_name); + }); +} + +void Mutation::AddControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::AddControllingFanin(&new_nodes_[node.index_], fanin_node_name); +} + +void Mutation::RemoveControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name) { + AddMutation(node, [node, fanin_node_name](MutableNodeViewDiff* diff) { + auto it = node->controlling_fanins_index_.find(fanin_node_name); + const int control_index = it != node->controlling_fanins_index_.end() + ? it->second + : internal::kMissingIndex; + internal::RemoveControllingFanin(diff, control_index, fanin_node_name); + }); +} + +void Mutation::RemoveControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::RemoveControllingFanin(&new_nodes_[node.index_], fanin_node_name); +} + +void Mutation::AddOrUpdateNodeAttr(MutableNodeView* node, + absl::string_view attr_name, + const AttrValue& attr_value) { + AddMutation(node, [attr_name, attr_value](MutableNodeViewDiff* diff) { + internal::AddOrUpdateAttribute(diff, attr_name, attr_value); + }); +} + +void Mutation::AddOrUpdateNodeAttr(const MutationNewNode& node, + absl::string_view attr_name, + const AttrValue& attr_value) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::AddOrUpdateAttribute(&new_nodes_[node.index_], attr_name, + attr_value); +} + +void Mutation::RemoveNodeAttr(MutableNodeView* node, + absl::string_view attr_name) { + AddMutation(node, [attr_name](MutableNodeViewDiff* diff) { + internal::RemoveAttribute(diff, attr_name); + }); +} + +void Mutation::RemoveNodeAttr(const MutationNewNode& node, + absl::string_view attr_name) { + DCHECK(node.mutation_ == this && node.mutation_counter_ == mutation_counter_); + internal::RemoveAttribute(&new_nodes_[node.index_], attr_name); +} + +void Mutation::ResetInternal() { + std::vector().swap(updated_nodes_); + std::vector().swap(new_nodes_); +} + +void Mutation::Reset() { + for (const auto& update : updated_nodes_) { + graph_view_->nodes_[update.node_index].update_index_ = + internal::kMissingIndex; + } + ResetInternal(); +} + +Status Mutation::Apply() { return graph_view_->ApplyMutationInternal(); } + +namespace { +const char kMutableGraphViewError[] = + "MutableGraphView::MutableGraphView error: "; + +const char kMutableGraphViewApplyError[] = "Mutation::Apply error: "; + +inline void IncrementFaninCount( + absl::flat_hash_map* fanins_count, + const internal::NodeDefAndPortIndex& fanin) { + ++(*fanins_count)[fanin]; +} + +inline void DecrementFaninCount( + absl::flat_hash_map* fanins_count, + const internal::NodeDefAndPortIndex& fanin) { + auto it = fanins_count->find(fanin); + if (it != fanins_count->end()) { + if (it->second <= 1) { + fanins_count->erase(it); + } else { + --it->second; + } + } +} +} // namespace + +MutableGraphView::MutableGraphView(GraphDef* graph, Status* status) + : GraphViewInternal(graph), mutation_(Mutation(this)) { + const int num_nodes = graph->node_size(); + node_index_by_name_.reserve(num_nodes); + nodes_.reserve(num_nodes); + for (NodeDef& node : *graph->mutable_node()) { + if (!AddUniqueNodeInternal(&node)) { + *status = errors::InvalidArgument( + kMutableGraphViewError, "graph has multiple nodes with the name '", + node.name(), "'."); + Reset(); + return; + } + } + std::vector> fanins; + Status s = CheckFaninsInternal(&fanins); + if (!s.ok()) { + *status = s; + Reset(); + return; + } + AddFaninsInternal(&fanins); + *status = Status::OK(); +} + +Mutation* MutableGraphView::GetMutationBuilder() { return &mutation_; } + +bool MutableGraphView::AddUniqueNodeInternal(NodeDef* node) { + const int node_index = node_index_by_name_.size(); + auto it = node_index_by_name_.emplace(node->name(), node_index); + if (it.second) { + nodes_.emplace_back(this, node_index); + return true; + } + return false; +} + +Status MutableGraphView::CheckFaninsInternal( + std::vector>* fanins) { + const int num_nodes = nodes_.size(); + fanins->reserve(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + bool has_observed_control = false; + const NodeDef* node = nodes_[i].node(); + const string& node_name = node->name(); + std::vector node_fanins; + node_fanins.reserve(node->input_size()); + for (const string& input : node->input()) { + TensorId fanin_id = ParseTensorName(input); + if (fanin_id.node() == node_name) { + return errors::InvalidArgument(kMutableGraphViewError, "node '", + node_name, "' has self cycle fanin '", + input, "'."); + } + bool is_control = IsTensorIdControl(fanin_id); + if (!is_control && has_observed_control) { + return errors::InvalidArgument(kMutableGraphViewError, "node '", + node_name, "' has regular fanin '", + input, "' after controlling fanins."); + } + if (!node_index_by_name_.contains(fanin_id.node())) { + return errors::InvalidArgument(kMutableGraphViewError, "node '", + node_name, "' has missing fanin '", + input, "'."); + } + if (is_control) { + has_observed_control = true; + } + node_fanins.push_back(std::move(fanin_id)); + } + fanins->push_back(std::move(node_fanins)); + } + return Status::OK(); +} + +void MutableGraphView::AddFaninsInternal( + std::vector>* fanins) { + const int num_nodes = nodes_.size(); + for (int i = 0; i < num_nodes; ++i) { + MutableNodeView& node_view = nodes_[i]; + NodeDef* node = node_view.node(); + std::vector& node_fanins = fanins->at(i); + absl::flat_hash_set observed_controls; + int pos = 0; + const int last_idx = node_fanins.size() - 1; + int last_pos = last_idx; + node_view.fanins_count_.reserve(node->input_size()); + node_view.controlling_fanins_index_.reserve(node->input_size()); + while (pos <= last_pos) { + const TensorId& fanin_id = node_fanins[pos]; + bool is_control = IsTensorIdControl(fanin_id); + const int fanin_node_index = node_index_by_name_[fanin_id.node()]; + MutableNodeView& fanin_node_view = nodes_[fanin_node_index]; + + if (is_control) { + if (gtl::InsertIfNotPresent(&observed_controls, fanin_id.node())) { + fanin_node_view.controlled_fanouts_.emplace_back( + this, i, Graph::kControlSlot, + node_view.controlling_fanins_.size()); + node_view.controlling_fanins_.emplace_back( + this, fanin_node_index, Graph::kControlSlot, + fanin_node_view.controlled_fanouts_.size() - 1); + IncrementFaninCount( + &node_view.fanins_count_, + {&graph_->node(fanin_node_index), Graph::kControlSlot}); + node_view.controlling_fanins_index_.emplace( + fanin_id.node(), pos - node_view.NumRegularFanins()); + ++pos; + } else { + node->mutable_input()->SwapElements(pos, last_pos); + std::swap(node_fanins[pos], node_fanins[last_pos]); + --last_pos; + } + } else { + if (fanin_node_view.regular_fanouts_by_port_.size() < + fanin_id.index() + 1) { + fanin_node_view.regular_fanouts_by_port_.resize(fanin_id.index() + 1); + } + auto& fanin_regular_fanouts = + fanin_node_view.regular_fanouts_by_port_[fanin_id.index()]; + fanin_regular_fanouts.emplace_back(this, i, + node_view.regular_fanins_.size(), + node_view.regular_fanins_.size()); + ++fanin_node_view.num_regular_fanouts_; + node_view.regular_fanins_.emplace_back( + this, fanin_node_index, fanin_id.index(), + fanin_regular_fanouts.size() - 1); + IncrementFaninCount( + &node_view.fanins_count_, + {&graph_->node(fanin_node_index), fanin_id.index()}); + ++pos; + } + } + if (last_pos < last_idx) { + node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos); + } + } +} + +Status MutableGraphView::GetNodeNamesAndPartitionUpdatedNodes( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices) { + for (const auto& diff : mutation_.updated_nodes_) { + // For all nodes to be removed and renamed, mark their original names as + // missing and put associated node index in graph. + if (diff.removed || diff.update_name) { + const int index = diff.node_index; + const string& node_name = nodes_[index].GetName(); + node_names->emplace(node_name, index); + } + } + + auto name_conflict = [](const absl::string_view node_name) { + return errors::InvalidArgument(kMutableGraphViewApplyError, + "multiple nodes with the name: '", node_name, + "' exists in Mutation."); + }; + + // Partition updated nodes by if they will be renamed or not. + const int num_updated_nodes = mutation_.updated_nodes_.size(); + renamed_nodes->reserve(num_updated_nodes); + inplace_nodes->reserve(num_updated_nodes); + empty_diff_node_indices->reserve(num_updated_nodes); + for (int i = 0; i < num_updated_nodes; ++i) { + auto& diff = mutation_.updated_nodes_[i]; + if (internal::IsEmpty(&diff)) { + empty_diff_node_indices->emplace_back(diff.node_index); + continue; + } else if (diff.removed) { + continue; + } + // Get name of updated node after potential mutation. + const string& node_name = + diff.update_name ? diff.name : nodes_[diff.node_index].GetName(); + auto it = node_names->insert({node_name, internal::kNodeNamePresent}); + if (!it.second) { + if (it.first->second == internal::kNodeNamePresent) { + // Another node in the mutation is already using this name, which will + // result in a conflict. + return name_conflict(node_name); + } else { + // Mark name as present (node was marked missing from either being + // removed or renamed). + it.first->second = internal::kNodeNamePresent; + } + } + if (diff.update_name) { + // Lookup new name of node in current graph. If a node has such name, + // store its index for later lookups as this node will be overwritten. + auto node_name_it = node_index_by_name_.find(node_name); + const int overwritten_node_index = + node_name_it != node_index_by_name_.end() ? node_name_it->second + : internal::kMissingIndex; + renamed_nodes->emplace_back(i, overwritten_node_index); + } else { + inplace_nodes->push_back(i); + } + } + + // Get names of new nodes after potential mutation. + for (const auto& new_node : mutation_.new_nodes_) { + const string& node_name = new_node.node.name(); + auto it = node_names->insert({node_name, internal::kNodeNamePresent}); + if (it.second) { + continue; + } + if (it.first->second == internal::kNodeNamePresent) { + // Another node in the mutation is already using this name, which will + // result in a conflict. + return name_conflict(node_name); + } else { + // Mark name as present (node was marked missing from either being removed + // or renamed). + it.first->second = internal::kNodeNamePresent; + } + } + + return Status::OK(); +} + +Status MutableGraphView::RemovedOrMissingNodeFanoutsWellFormed( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes) { + auto bad_fanout = [](absl::string_view fanout_node_name, + absl::string_view node_name) { + return errors::InvalidArgument( + kMutableGraphViewApplyError, "fanout '", fanout_node_name, + "' exist for missing node '", node_name, "'."); + }; + + // Lookup nodes to be overwritten. + std::vector overwritten_nodes(NumNodes()); + for (auto& renamed_node : renamed_nodes) { + if (renamed_node.overwritten_node_index_ == internal::kMissingIndex) { + continue; + } + overwritten_nodes[renamed_node.overwritten_node_index_] = true; + } + + // Check if removed nodes and previous state of renamed nodes have no fanouts. + for (const auto& node_name_state : node_names) { + if (node_name_state.second == internal::kNodeNamePresent) { + continue; + } + const MutableNodeView& node_view = nodes_[node_name_state.second]; + for (const auto& regular_fanouts : node_view.GetRegularFanouts()) { + for (const auto& regular_fanout : regular_fanouts) { + // Check all fanouts of a single port. + MutableNodeView* fanout_view = regular_fanout.node_view(); + if (fanout_view->update_index_ == internal::kMissingIndex) { + if (!overwritten_nodes[fanout_view->node_index_]) { + // Fanout is not updated or removed/overwritten. + return bad_fanout(fanout_view->GetName(), node_name_state.first); + } + } else { + auto& diff = mutation_.updated_nodes_[fanout_view->update_index_]; + if (diff.removed) { + // Fanout node will be removed, this can be ignored. + continue; + } + const int last_index = fanout_view->NumRegularFanins() - + diff.num_regular_inputs_to_remove - 1; + if (regular_fanout.index() > last_index) { + // Fanin of fanout is removed, this can be ignored. + continue; + } + // Check if fanin is updated. + if (diff.regular_inputs_to_update.find(regular_fanout.index()) == + diff.regular_inputs_to_update.end()) { + return bad_fanout(fanout_view->GetName(), node_name_state.first); + } + } + } + } + for (const auto& controlled_fanout : node_view.GetControlledFanouts()) { + MutableNodeView* fanout_view = controlled_fanout.node_view(); + if (fanout_view->update_index_ == internal::kMissingIndex) { + if (!overwritten_nodes[fanout_view->node_index_]) { + // Fanout is not updated or removed/overwritten. + return bad_fanout(fanout_view->GetName(), node_name_state.first); + } + } else { + auto& diff = mutation_.updated_nodes_[fanout_view->update_index_]; + if (diff.removed) { + // Fanout node will be removed, this can be ignored. + continue; + } + // Check if controlling fanin is removed. + if (diff.controlling_inputs_to_remove.find( + controlled_fanout.fanin_index_) == + diff.controlling_inputs_to_remove.end()) { + return bad_fanout(fanout_view->GetName(), node_name_state.first); + } + } + } + } + + return Status::OK(); +} + +Status MutableGraphView::CheckNodeNamesAndFanins( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes, + const std::vector& inplace_nodes) { + // Check if removed/missing node fanouts are valid. + TF_RETURN_IF_ERROR( + RemovedOrMissingNodeFanoutsWellFormed(node_names, renamed_nodes)); + + // Check if updated nodes and their fanins are valid. + for (auto& inplace_node : inplace_nodes) { + auto& diff = mutation_.updated_nodes_[inplace_node]; + if (!internal::IsWellFormed(&diff, node_names)) { + return errors::InvalidArgument( + kMutableGraphViewApplyError, "inplace updated node '", + nodes_[diff.node_index].GetName(), "' is ill-formed."); + } + } + for (auto& renamed_node : renamed_nodes) { + auto& diff = mutation_.updated_nodes_[renamed_node.renamed_update_index_]; + if (!internal::IsWellFormed(&diff, node_names)) { + return errors::InvalidArgument( + kMutableGraphViewApplyError, "renamed updated node '", diff.name, + "' ('", nodes_[diff.node_index].GetName(), "') is ill-formed."); + } + } + + // Check if new nodes and their fanins are valid. + for (auto& new_node : mutation_.new_nodes_) { + if (!internal::IsWellFormed(&new_node, node_names)) { + return errors::InvalidArgument(kMutableGraphViewApplyError, "new node '", + new_node.node.name(), "' is ill-formed."); + } + } + + return Status::OK(); +} + +Status MutableGraphView::CheckKernelRegisteredForNodes() { + Status s; + for (auto& diff : mutation_.updated_nodes_) { + if (internal::IsEmpty(&diff) || diff.removed) { + continue; + } + + NodeDef* node = nodes_[diff.node_index].node(); + diff.processed_attrs = + AttrValueMap(node->attr().begin(), node->attr().end()); + for (const auto& attr_to_remove : diff.attrs_to_remove) { + diff.processed_attrs.erase(attr_to_remove); + } + for (const auto& attr_to_add : diff.attrs_to_add) { + gtl::InsertOrUpdate(&diff.processed_attrs, attr_to_add.first, + attr_to_add.second); + } + const string& device = diff.update_device ? diff.device : node->device(); + if (device.empty()) { + continue; + } + s = IsKernelRegisteredForNode(diff.update_name ? diff.name : node->name(), + node->has_experimental_debug_info(), + node->experimental_debug_info(), + diff.update_op ? diff.op : node->op(), device, + AttrSlice(&diff.processed_attrs)); + if (!s.ok()) { + return errors::InvalidArgument(kMutableGraphViewApplyError, + s.error_message()); + } + } + for (const auto& new_node_holder : mutation_.new_nodes_) { + const auto& new_node_def = new_node_holder.node; + if (new_node_def.device().empty()) { + continue; + } + s = IsKernelRegisteredForNode(new_node_def); + if (!s.ok()) { + return errors::InvalidArgument(kMutableGraphViewApplyError, + s.error_message()); + } + } + return Status::OK(); +} + +template +void MutableGraphView::ReplaceNodeFanouts(MutableNodeView* node, T* fanouts) { + node->num_regular_fanouts_ = fanouts->num_regular_fanouts_; + node->regular_fanouts_by_port_ = std::move(fanouts->regular_fanouts_by_port_); + for (int i = 0; i < node->regular_fanouts_by_port_.size(); ++i) { + for (int j = 0; j < node->regular_fanouts_by_port_[i].size(); ++j) { + auto& fanout = node->regular_fanouts_by_port_[i][j]; + auto* fanout_node_view = fanout.node_view(); + auto& fanout_fanin = fanout_node_view->regular_fanins_[fanout.index()]; + auto* fanout_fanins_count = &fanout_node_view->fanins_count_; + DecrementFaninCount( + fanout_fanins_count, + {&graph_->node(fanout_fanin.node_index_), fanout_fanin.index()}); + fanout_fanin.node_index_ = node->node_index_; + IncrementFaninCount( + fanout_fanins_count, + {&graph_->node(node->node_index_), fanout_fanin.index()}); + } + } + node->controlled_fanouts_ = std::move(fanouts->controlled_fanouts_); + for (int i = 0; i < node->controlled_fanouts_.size(); ++i) { + auto& fanout = node->controlled_fanouts_[i]; + auto* fanout_node_view = fanout.node_view(); + auto& fanout_fanin = + fanout_node_view->controlling_fanins_[fanout.fanin_index_]; + auto* fanout_fanins_count = &fanout_node_view->fanins_count_; + DecrementFaninCount( + fanout_fanins_count, + {&graph_->node(fanout_fanin.node_index_), Graph::kControlSlot}); + fanout_fanin.node_index_ = node->node_index_; + fanout_fanin.fanout_index_ = i; + IncrementFaninCount(fanout_fanins_count, {&graph_->node(node->node_index_), + Graph::kControlSlot}); + } +} + +void MutableGraphView::FixRenamedNodes( + std::vector* renamed_nodes, + absl::flat_hash_map* renamed_fanouts, + std::vector* overwritten_name_removed_nodes) { + // Extract all renamed node fanouts. + renamed_fanouts->reserve(renamed_nodes->size()); + for (auto& renamed : *renamed_nodes) { + auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_]; + // Remove node index by name from graph. + node_index_by_name_.erase(nodes_[diff.node_index].GetName()); + MutableNodeView& renamed_node = nodes_[diff.node_index]; + renamed_fanouts->try_emplace( + renamed_node.GetName(), + std::move(renamed_node.regular_fanouts_by_port_), + renamed_node.num_regular_fanouts_, + std::move(renamed_node.controlled_fanouts_)); + } + + // Replace renamed node fanouts with fanouts associated with updated name. + for (auto& renamed : *renamed_nodes) { + auto& diff = mutation_.updated_nodes_[renamed.renamed_update_index_]; + MutableNodeView& renamed_node = nodes_[diff.node_index]; + auto fanouts_it = renamed_fanouts->find(diff.name); + if (fanouts_it != renamed_fanouts->end()) { + // Another renamed node's fanout. + auto& fanouts = fanouts_it->second; + ReplaceNodeFanouts(&renamed_node, &fanouts); + renamed_fanouts->erase(fanouts_it); + // Node to be overwritten is being renamed, so it won't be overwritten. + renamed.overwritten_node_index_ = internal::kMissingIndex; + } else if (renamed.overwritten_node_index_ != internal::kMissingIndex) { + // Existing node in graph. + MutableNodeView& node_to_overwrite = + nodes_[renamed.overwritten_node_index_]; + ReplaceNodeFanouts(&renamed_node, &node_to_overwrite); + node_index_by_name_.erase(node_to_overwrite.GetName()); + if (node_to_overwrite.update_index_ != internal::kMissingIndex && + mutation_.updated_nodes_[node_to_overwrite.update_index_].removed) { + (*overwritten_name_removed_nodes)[node_to_overwrite.update_index_] = + true; + } + } else { + // No existing fanouts. + renamed_node.num_regular_fanouts_ = 0; + } + + // Update node name. + renamed_node.node()->set_name(diff.name); + diff.update_name = false; + diff.name.clear(); + // Rehash renamed nodes with updated name. + node_index_by_name_.emplace(renamed_node.GetName(), diff.node_index); + } +} + +void MutableGraphView::AddNewNodes( + absl::flat_hash_map* renamed_fanouts, + std::vector* new_node_indices) { + new_node_indices->reserve(mutation_.new_nodes_.size()); + for (auto& new_node : mutation_.new_nodes_) { + int node_index; + auto graph_it = node_index_by_name_.find(new_node.node.name()); + if (graph_it != node_index_by_name_.end()) { + // Overwrite existing node. + node_index = graph_it->second; + MutableNodeView& node_view = nodes_[node_index]; + RemoveAllFaninFanoutInternal(&node_view); + auto* node_def = graph_->mutable_node(node_index); + node_def->mutable_op()->swap(*new_node.node.mutable_op()); + node_def->mutable_device()->swap(*new_node.node.mutable_device()); + node_def->mutable_input()->Clear(); + node_def->mutable_attr()->swap(*new_node.node.mutable_attr()); + if (node_view.update_index_ != internal::kMissingIndex) { + // The only case for this to occur is if a node is explicitly marked for + // removal. In that case, unlink it from it's associated + // MutableNodeViewDiff. + mutation_.updated_nodes_[node_view.update_index_].node_index = + internal::kMissingIndex; + mutation_.updated_nodes_[node_view.update_index_].removed = false; + node_view.update_index_ = internal::kMissingIndex; + } + } else { + // New node. + auto* new_node_def = graph_->add_node(); + *new_node_def = std::move(new_node.node); + node_index = nodes_.size(); + nodes_.emplace_back(this, node_index); + MutableNodeView& new_node_view = nodes_.back(); + auto it = renamed_fanouts->find(new_node_view.GetName()); + if (it != renamed_fanouts->end()) { + // Reuse fanouts of renamed node. + NodeViewFanouts& fanouts = it->second; + ReplaceNodeFanouts(&new_node_view, &fanouts); + renamed_fanouts->erase(it); + } + node_index_by_name_.emplace(new_node_view.GetName(), node_index); + } + new_node_indices->emplace_back(node_index); + } +} + +void MutableGraphView::FixRenamedFanouts( + const absl::flat_hash_map& renamed_fanouts) { + // Leftover fanouts in renamed_fanouts are due to nodes not existing anymore + // or a node being renamed without another node taking its place. For these + // leftover fanouts, mark their respective fanin fanout_index_ to + // internal::kMissingIndex as an indicator so when it comes to updating or + // removing fanins inplace, nodes with the same index don't get affected and + // other fanouts are accidently removed. + for (auto& renamed_fanout : renamed_fanouts) { + for (auto& regular_fanouts : + renamed_fanout.second.regular_fanouts_by_port_) { + for (auto& fanout : regular_fanouts) { + auto* fanout_node_view = fanout.node_view(); + auto& fanin = fanout_node_view->regular_fanins_[fanout.index()]; + fanout_node_view->fanins_count_.erase( + {fanin.node_view()->node(), fanin.index()}); + fanin.fanout_index_ = internal::kMissingIndex; + } + } + for (auto& fanout : renamed_fanout.second.controlled_fanouts_) { + auto* fanout_node_view = fanout.node_view(); + auto& fanin = fanout_node_view->controlling_fanins_[fanout.fanin_index_]; + fanout_node_view->fanins_count_.erase( + {fanin.node_view()->node(), Graph::kControlSlot}); + fanout_node_view->controlling_fanins_index_.erase(renamed_fanout.first); + fanin.fanout_index_ = internal::kMissingIndex; + } + } +} + +inline void MutableGraphView::RemoveRegularFaninFanoutInternal( + MutableNodeView* node_view, int i) { + MutableFanoutView& fanin = node_view->regular_fanins_[i]; + // Fanin was marked as removed via FixRenamedFanouts. + if (fanin.fanout_index_ == internal::kMissingIndex) { + return; + } + + DecrementFaninCount(&node_view->fanins_count_, + {&graph_->node(fanin.node_index_), fanin.index()}); + auto* fanin_node_view = fanin.node_view(); + auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin.index()]; + if (fanin.fanout_index_ < fanouts.size() - 1) { + // Swap fanout with last fanout in vector, and update it's associated fanin + // index. + MutableFaninView& last_fanout = fanouts.back(); + last_fanout.node_view() + ->regular_fanins_[last_fanout.index()] + .fanout_index_ = fanin.fanout_index_; + std::swap(last_fanout, fanouts[fanin.fanout_index_]); + } + // Remove fanout. + fanouts.pop_back(); + --fanin.node_view()->num_regular_fanouts_; + + // Resize fanouts. Fanouts may not be removed sequentially in relation to + // output port, so trailing empty output ports may be left behind. It is + // necessary to loop through all of the output ports to determine the maximum + // output port before resizing. + int last_fanout_index = fanin_node_view->regular_fanouts_by_port_.size(); + for (int i = fanin_node_view->regular_fanouts_by_port_.size() - 1; i >= 0; + --i) { + if (fanin_node_view->regular_fanouts_by_port_[i].empty()) { + last_fanout_index = i; + } else { + break; + } + } + if (last_fanout_index < fanin_node_view->regular_fanouts_by_port_.size()) { + fanin_node_view->regular_fanouts_by_port_.resize(last_fanout_index); + } +} + +inline void MutableGraphView::AddRegularFaninInternal( + MutableNodeView* node_view, const SafeTensorId& fanin_id) { + MutableNodeView* fanin_node_view = GetNode(fanin_id.node()); + // Resize fanouts to include new output port index. + if (fanin_node_view->regular_fanouts_by_port_.size() < fanin_id.index() + 1) { + fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1); + } + + // Add node as fanout to fanin. + auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()]; + fanouts.emplace_back(this, node_view->node_index(), + node_view->regular_fanins_.size(), + node_view->regular_fanins_.size()); + ++fanin_node_view->num_regular_fanouts_; + + // Add fanin to node. + node_view->regular_fanins_.emplace_back(this, fanin_node_view->node_index(), + fanin_id.index(), fanouts.size() - 1); + IncrementFaninCount( + &node_view->fanins_count_, + {&graph_->node(fanin_node_view->node_index()), fanin_id.index()}); +} + +inline void MutableGraphView::UpdateRegularFaninInternal( + MutableNodeView* node_view, const int i, const SafeTensorId& fanin_id) { + // Remove fanin. + RemoveRegularFaninFanoutInternal(node_view, i); + + MutableNodeView* fanin_node_view = GetNode(fanin_id.node()); + // Resize fanouts to include new output port index. + if (fanin_node_view->regular_fanouts_by_port_.size() < fanin_id.index() + 1) { + fanin_node_view->regular_fanouts_by_port_.resize(fanin_id.index() + 1); + } + + // Add node as fanout to fanin. + auto& fanouts = fanin_node_view->regular_fanouts_by_port_[fanin_id.index()]; + fanouts.emplace_back(this, node_view->node_index(), i, i); + ++fanin_node_view->num_regular_fanouts_; + + // Replace fanin in node. + node_view->regular_fanins_[i] = + MutableFanoutView(this, fanin_node_view->node_index(), fanin_id.index(), + fanouts.size() - 1); + IncrementFaninCount( + &node_view->fanins_count_, + {&graph_->node(fanin_node_view->node_index()), fanin_id.index()}); +} + +inline void MutableGraphView::RemoveControllingFaninFanoutInternal( + MutableNodeView* node_view, int i) { + auto& control_to_remove = node_view->controlling_fanins_[i]; + if (control_to_remove.fanout_index_ != internal::kMissingIndex) { + // Update internal state associated with node. + node_view->fanins_count_.erase( + {control_to_remove.node_view()->node(), Graph::kControlSlot}); + node_view->controlling_fanins_index_.erase( + control_to_remove.node_view()->GetName()); + + // Remove controlled fanout from controlling fanin, via swapping last + // controlled fanout in controlling fanin with controlled fanout to be + // removed. + auto* control_to_remove_view = control_to_remove.node_view(); + if (control_to_remove.fanout_index_ < + control_to_remove_view->controlled_fanouts_.size() - 1) { + auto& control_to_remove_view_last_control = + control_to_remove_view->controlled_fanouts_.back(); + control_to_remove_view_last_control.node_view() + ->controlling_fanins_[control_to_remove_view_last_control + .fanin_index_] + .fanout_index_ = control_to_remove.fanout_index_; + std::swap(control_to_remove_view_last_control, + control_to_remove_view + ->controlled_fanouts_[control_to_remove.fanout_index_]); + } + control_to_remove_view->controlled_fanouts_.pop_back(); + } +} + +inline void MutableGraphView::RemoveControllingFaninInternal( + MutableNodeView* node_view, const std::set& indices_to_remove) { + const int num_regular_fanins = node_view->NumRegularFanins(); + auto* mutable_input = node_view->node()->mutable_input(); + // Iterate in descending order so indices stay consistent. + for (auto rit = indices_to_remove.rbegin(); rit != indices_to_remove.rend(); + ++rit) { + const int control_index = *rit; + RemoveControllingFaninFanoutInternal(node_view, control_index); + + // Swap last controlling fanin in node with controlling fanin to be removed. + if (control_index < node_view->controlling_fanins_.size() - 1) { + auto& last_control = node_view->controlling_fanins_.back(); + auto* last_control_view = last_control.node_view(); + last_control_view->controlled_fanouts_[last_control.fanout_index_] + .fanin_index_ = control_index; + node_view->controlling_fanins_index_.find(last_control_view->GetName()) + ->second = control_index; + mutable_input->SwapElements( + num_regular_fanins + control_index, + num_regular_fanins + node_view->NumControllingFanins() - 1); + std::swap(last_control, node_view->controlling_fanins_[control_index]); + } + mutable_input->RemoveLast(); + node_view->controlling_fanins_.pop_back(); + } +} + +inline void MutableGraphView::AddControllingFaninInternal( + MutableNodeView* node_view, absl::string_view fanin_node_name) { + NodeDef* node = node_view->node(); + // Add controlling fanin to NodeDef. + node->add_input(AsControlDependency(string(fanin_node_name))); + MutableNodeView* fanin_node_view = GetNode(fanin_node_name); + const int index = node_view->controlling_fanins_.size(); + fanin_node_view->controlled_fanouts_.emplace_back( + this, node_view->node_index(), Graph::kControlSlot, index); + node_view->controlling_fanins_.emplace_back( + this, fanin_node_view->node_index(), Graph::kControlSlot, + fanin_node_view->controlled_fanouts_.size() - 1); + IncrementFaninCount( + &node_view->fanins_count_, + {&graph_->node(fanin_node_view->node_index()), Graph::kControlSlot}); + // Parse new fanin string for node name. + TensorId tensor_id = ParseTensorName(node->input(node->input_size() - 1)); + node_view->controlling_fanins_index_.emplace(tensor_id.node(), index); +} + +void MutableGraphView::ApplyNodeUpdates() { + for (auto& diff : mutation_.updated_nodes_) { + if (diff.removed || diff.node_index == internal::kMissingIndex || + internal::IsEmpty(&diff)) { + continue; + } + MutableNodeView& node_view = nodes_[diff.node_index]; + diff.node_index = internal::kMissingIndex; + // Clean up node view. + node_view.update_index_ = internal::kMissingIndex; + + NodeDef* node_def = node_view.node(); + + // Set updated fields and attributes of node. + if (diff.update_op) { + node_def->set_op(diff.op); + } + if (diff.update_device) { + node_def->set_device(diff.device); + } + node_def->mutable_attr()->swap(diff.processed_attrs); + + // Updated fanins. Only one of `regular_inputs_to_remove_` or + // `regular_inputs_to_add_` can be set. + if (diff.num_regular_inputs_to_remove > 0) { + // Truncate trailing regular fanins. + const int first_index = + node_view.NumRegularFanins() - diff.num_regular_inputs_to_remove; + for (int i = first_index; i < node_view.NumRegularFanins(); ++i) { + RemoveRegularFaninFanoutInternal(&node_view, i); + } + node_view.regular_fanins_.resize(first_index); + node_def->mutable_input()->DeleteSubrange( + node_view.NumRegularFanins(), diff.num_regular_inputs_to_remove); + } else if (diff.num_regular_inputs_to_add > 0) { + // Append regular fanins. + node_def->mutable_input()->Reserve(node_def->mutable_input()->size() + + diff.num_regular_inputs_to_add); + int curr_index = node_view.NumRegularFanins(); + int curr_control_start = curr_index; + for (const SafeTensorId& fanin : diff.regular_inputs_to_add) { + AddRegularFaninInternal(&node_view, fanin); + node_def->add_input(SafeTensorIdToString(fanin)); + node_def->mutable_input()->SwapElements(curr_index, + node_def->input_size() - 1); + if (curr_control_start == curr_index) { + curr_control_start = node_def->input_size() - 1; + } + ++curr_index; + } + // Rotate shifted controlling fanins to match up with + // `node_view.controlling_fanins_` as `num_regular_inputs_to_add_` may not + // be a multiple of `num_regular_inputs_to_add_`. This is to prevent + // rehashing controlling fanins in `node_view.controlling_fanins_index_`. + if (node_view.NumControllingFanins() > 1 && + curr_control_start != node_view.NumRegularFanins()) { + std::rotate( + node_def->mutable_input()->begin() + node_view.NumRegularFanins(), + node_def->mutable_input()->begin() + curr_control_start, + node_def->mutable_input()->end()); + } + } + + for (const auto& update_fanin : diff.regular_inputs_to_update) { + UpdateRegularFaninInternal(&node_view, update_fanin.first, + update_fanin.second); + node_def->set_input(update_fanin.first, + SafeTensorIdToString(update_fanin.second)); + } + + RemoveControllingFaninInternal(&node_view, + diff.controlling_inputs_to_remove); + + node_def->mutable_input()->Reserve(node_def->mutable_input()->size() + + diff.controlling_inputs_to_add.size()); + for (const auto& control_to_add : diff.controlling_inputs_to_add) { + AddControllingFaninInternal(&node_view, control_to_add); + } + } +} + +void MutableGraphView::SetNewNodesFanins( + const std::vector& new_node_indices) { + auto new_node = mutation_.new_nodes_.begin(); + for (const int new_node_index : new_node_indices) { + MutableNodeView& new_node_view = nodes_[new_node_index]; + NodeDef* new_node_def = new_node_view.node(); + new_node_def->mutable_input()->Reserve(new_node->num_regular_fanins + + new_node->controlling_fanins.size()); + for (const SafeTensorId& fanin : new_node->regular_fanins) { + AddRegularFaninInternal(&new_node_view, fanin); + new_node_def->add_input(SafeTensorIdToString(fanin)); + } + for (const string& control_to_add : new_node->controlling_fanins) { + AddControllingFaninInternal(&new_node_view, control_to_add); + } + ++new_node; + } +} + +inline void MutableGraphView::RemoveAllFaninFanoutInternal( + MutableNodeView* node_view) { + const int num_regular_fanins = node_view->NumRegularFanins(); + for (int i = 0; i < num_regular_fanins; ++i) { + RemoveRegularFaninFanoutInternal(node_view, i); + } + std::vector().swap(node_view->regular_fanins_); + const int num_controlling_fanins = node_view->NumControllingFanins(); + for (int i = 0; i < num_controlling_fanins; ++i) { + RemoveControllingFaninFanoutInternal(node_view, i); + } + std::vector().swap(node_view->controlling_fanins_); +} + +void MutableGraphView::RemoveNodesInternal( + const std::vector& renamed_nodes, + const std::vector& overwritten_name_removed_nodes) { + // Get all nodes overwritten by renamed nodes and remove their fanins. + std::vector overwritten_nodes; + overwritten_nodes.reserve(renamed_nodes.size()); + for (const auto& renamed : renamed_nodes) { + if (renamed.overwritten_node_index_ != internal::kMissingIndex) { + auto& node = nodes_[renamed.overwritten_node_index_]; + RemoveAllFaninFanoutInternal(&node); + overwritten_nodes.emplace_back(renamed.overwritten_node_index_); + } + } + + // Get all nodes explicitly marked for removal and remove their fanins. + std::vector node_indices_to_remove; + node_indices_to_remove.reserve(mutation_.updated_nodes_.size() + + overwritten_nodes.size()); + for (int i = 0; i < mutation_.updated_nodes_.size(); ++i) { + const auto& diff = mutation_.updated_nodes_[i]; + if (diff.removed) { + auto& node = nodes_[diff.node_index]; + RemoveAllFaninFanoutInternal(&node); + node_indices_to_remove.push_back(diff.node_index); + if (!overwritten_name_removed_nodes[i]) { + node_index_by_name_.erase(node.GetName()); + } + } + } + node_indices_to_remove.insert(node_indices_to_remove.end(), + overwritten_nodes.begin(), + overwritten_nodes.end()); + std::set sorted_node_indices_to_remove(node_indices_to_remove.begin(), + node_indices_to_remove.end()); + + // Iterate in descending order so indices stay consistent. + for (auto rit = sorted_node_indices_to_remove.rbegin(); + rit != sorted_node_indices_to_remove.rend(); ++rit) { + const int removed_node_index = *rit; + MutableNodeView& last_node = nodes_.back(); + if (last_node.node_index_ > removed_node_index) { + last_node.node_index_ = removed_node_index; + for (auto& regular_fanin : last_node.regular_fanins_) { + // Update fanouts of regular fanins with new index. + regular_fanin.node_view() + ->regular_fanouts_by_port_[regular_fanin.index()] + [regular_fanin.fanout_index_] + .node_index_ = removed_node_index; + } + for (auto& controlling_fanin : last_node.controlling_fanins_) { + // Update fanouts of controlling fanins with new index. + controlling_fanin.node_view() + ->controlled_fanouts_[controlling_fanin.fanout_index_] + .node_index_ = removed_node_index; + } + for (auto& regular_fanouts : last_node.regular_fanouts_by_port_) { + for (auto& regular_fanout : regular_fanouts) { + // Update fanins of regular fanouts. + MutableNodeView* fanout_node_view = regular_fanout.node_view(); + fanout_node_view->regular_fanins_[regular_fanout.fanin_index_] + .node_index_ = removed_node_index; + } + } + for (auto& controlled_fanout : last_node.controlled_fanouts_) { + // Update fanins of controlled fanouts. + MutableNodeView* fanout_node_view = controlled_fanout.node_view(); + fanout_node_view->controlling_fanins_[controlled_fanout.fanin_index_] + .node_index_ = removed_node_index; + } + + const int last_node_index = nodes_.size() - 1; + std::swap(nodes_[last_node_index], nodes_[removed_node_index]); + graph()->mutable_node()->SwapElements(last_node_index, + removed_node_index); + node_index_by_name_.find(nodes_[removed_node_index].GetName())->second = + removed_node_index; + } + nodes_.pop_back(); + graph()->mutable_node()->RemoveLast(); + } +} + +inline Status MutableGraphView::ValidateInternal( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices) { + // Get node names and partition updated_nodes_ by if they are renamed or not, + // skipping empty MutableNodeViewDiff. + TF_RETURN_IF_ERROR(GetNodeNamesAndPartitionUpdatedNodes( + node_names, renamed_nodes, inplace_nodes, empty_diff_node_indices)); + + // Check existence of fanins and validity (i.e. no self loops). + TF_RETURN_IF_ERROR( + CheckNodeNamesAndFanins(*node_names, *renamed_nodes, *inplace_nodes)); + + // Check if nodes after mutation have kernels registered. + TF_RETURN_IF_ERROR(CheckKernelRegisteredForNodes()); + + return Status::OK(); +} + +Status MutableGraphView::ApplyMutationInternal() { + // Node name -> node index mapping. If a node index is -1, the associated node + // with key node name exists. Otherwise the node index is the node's index in + // the graph. + absl::flat_hash_map node_names; + // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are + // renamed (and possibly have other fields mutated). + std::vector renamed_nodes; + // Indices of MutableNodeViewDiff in Mutation::updated_nodes_ where nodes are + // not renamed but have fields mutated. + std::vector inplace_nodes; + // Indices of nodes in graph where MutableNodeViewDiff are empty. + // `update_index_` of nodes associated to empty MutableNodeViewDiff should be + // cleared after validation success. + std::vector empty_diff_node_indices; + + // Check if this mutation is valid before applying, and partition + // updated_nodes_ into inplace mutated nodes and renamed nodes. + TF_RETURN_IF_ERROR(ValidateInternal( + &node_names, &renamed_nodes, &inplace_nodes, &empty_diff_node_indices)); + + // Clear `update_index_` of MutableNodeView with empty associated + // MutableNodeViewDiff. + for (const int empty_diff_node_index : empty_diff_node_indices) { + nodes_[empty_diff_node_index].update_index_ = internal::kMissingIndex; + } + + // Node name and associated fanouts. + absl::flat_hash_map renamed_fanouts; + // Removed nodes where name was overwritten by a renamed node. + std::vector overwritten_name_removed_nodes; + overwritten_name_removed_nodes.resize(mutation_.updated_nodes_.size(), false); + // Fix renaming of existing nodes by swapping fanouts and rehashing names. + // This will also overwrite removed or unmodified nodes. + FixRenamedNodes(&renamed_nodes, &renamed_fanouts, + &overwritten_name_removed_nodes); + + // Indices of nodes in graph where new nodes were inserted/appended. These + // will be corresponding to `new_nodes_` in order. + std::vector new_node_indices; + // Add new nodes, overwriting removed or unmodified nodes. + AddNewNodes(&renamed_fanouts, &new_node_indices); + + // For abandoned fanouts, mark their respective fanins so the original node + // associated will not have their fanouts removed and be left in an + // inconsistent state. + FixRenamedFanouts(renamed_fanouts); + + // Apply mutations to updated nodes (renamed nodes are treated as inplace + // nodes as they have already been renamed). Removed nodes are ignored. + ApplyNodeUpdates(); + + // Set fanins of new nodes. + SetNewNodesFanins(new_node_indices); + + // Remove overwritten nodes and updated nodes set to be removed. + RemoveNodesInternal(renamed_nodes, overwritten_name_removed_nodes); + + mutation_.ResetInternal(); + + mutation_.mutation_counter_++; + + return Status::OK(); +} + +} // namespace utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/graph_view.h b/tensorflow/core/grappler/utils/graph_view.h new file mode 100644 index 00000000000..18f7c4ab560 --- /dev/null +++ b/tensorflow/core/grappler/utils/graph_view.h @@ -0,0 +1,490 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/utils/graph_view_internal.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace utils { + +class NodeView; + +class GraphView; + +// FaninView is a helper class to represent fanouts of a node. This holds a +// pointer to GraphView, the index of the node being represented from GraphView, +// and the input index (hence is labeled as Fanin). +class FaninView : public internal::NodeIndexAndPortIndex { + public: + FaninView() : NodeIndexAndPortIndex() {} + + FaninView(GraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + FaninView(NodeView* node_view, int index); + + private: + friend class NodeView; + friend class GraphView; +}; + +// FanoutView is a helper class to represent fanins of a node. This holds a +// pointer to GraphView, the index of the node being represented from GraphView, +// and the output index (hence is labeled as Fanout). +class FanoutView : public internal::NodeIndexAndPortIndex { + public: + FanoutView() : NodeIndexAndPortIndex() {} + + FanoutView(GraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + FanoutView(NodeView* node_view, int index); + + private: + friend class NodeView; + friend class GraphView; +}; + +// Immutable NodeView that keeps the constness of the NodeDef. This allows for +// lookups of fanins and fanouts, and traversals of the graph, but no mutations. +// No dedupping of fanins will be performed on the node to preserve it's +// constness. +class NodeView : public internal::NodeViewInternal { + public: + using NodeViewInternal::NodeViewInternal; + ~NodeView() override = default; + + const NodeDef* node() const override; + + // Checks if a fanin exists for the node. + bool HasFanin(const FanoutView& fanin) const override; + + // Checks if a fanout exists for the node. + bool HasFanout(const FaninView& fanout) const override; + + private: + inline const FanoutView& GetMissingFanin() const override; + + inline const std::vector& GetMissingFanout() const override; + + absl::flat_hash_set fanins_set_; + + friend class FaninView; + friend class FanoutView; + friend class GraphView; +}; + +// Immutable GraphView that keeps the constness of the GraphDef. This allows +// for lookups and traversals of the graph, but no mutations. +class GraphView : public internal::GraphViewInternal { + public: + explicit GraphView(const GraphDef* graph, Status* status); + ~GraphView() override = default; + + private: + bool AddUniqueNodeInternal(const NodeDef* node); + + Status CheckAndAddFaninsInternal(NodeView* node_view); + + friend class NodeView; +}; + +class MutableNodeView; + +class MutableGraphView; + +class Mutation; + +// MutableFaninView is a helper class to represent fanouts of a node. This holds +// a pointer to MutableGraphView, the index of the node from MutableGraphView +// being mutated, and the input index (hence is labeled as Fanin). +class MutableFaninView + : public internal::NodeIndexAndPortIndex { + public: + MutableFaninView() : NodeIndexAndPortIndex() {} + + MutableFaninView(MutableGraphView* graph_view, int node_index, int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + explicit MutableFaninView(MutableGraphView* graph_view, int node_index, + int port_index, int fanin_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index), + fanin_index_(fanin_index) { + // TODO(lyandy): Remove once constructor is not public. + DCHECK(port_index < 0 || port_index == fanin_index); + } + + MutableFaninView(MutableNodeView* node_view, int index); + + private: + // Index of associated fanin in fanout's underlying MutableNodeView. For + // regular fanouts, this will be the same as port_index (index of the + // associated fanin in MutableNodeView::regular_fanins_). For controlled + // fanouts, this will be the index of the associated fanin in + // MutableNodeView::controlling_fanins_. + int fanin_index_ = internal::kMissingIndex; + + friend class MutableNodeView; + friend class MutableGraphView; + friend class Mutation; +}; + +// MutableFanoutView is a helper class to represent fanins of a node. This holds +// a pointer to MutableGraphView, the index of the node from MutableGraphView +// being mutated, and the output index (hence is labeled as Fanout). +class MutableFanoutView + : public internal::NodeIndexAndPortIndex { + public: + MutableFanoutView() : NodeIndexAndPortIndex() {} + + MutableFanoutView(MutableGraphView* graph_view, int node_index, + int port_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index) {} + + explicit MutableFanoutView(MutableGraphView* graph_view, int node_index, + int port_index, int fanout_index) + : NodeIndexAndPortIndex(graph_view, node_index, port_index), + fanout_index_(fanout_index) {} + + MutableFanoutView(MutableNodeView* node_view, int index); + + private: + // Index of associated fanout in fanin's underlying MutableNodeView. For + // regular fanins, this will be the index of the associated fanout in + // MutableNodeView::regular_fanouts_by_port_[port_index]. For controlled + // fanins, this will be the index of the associated fanout in + // MutableNodeView::controlled_fanouts_. + int fanout_index_ = internal::kMissingIndex; + + friend class MutableNodeView; + friend class MutableGraphView; + friend class Mutation; +}; + +// Mutable NodeView that holds a mutable NodeDef. This allows for lookups of +// fanins and fanouts, and traversals of the graph. Control dependencies will be +// dedupped among other control dependencies on initialization via +// MutableGraphView. Mutations should be handled via MutableGraphView and not +// directly on the mutable NodeDef. +class MutableNodeView + : public internal::NodeViewInternal { + public: + using NodeViewInternal::NodeViewInternal; + ~MutableNodeView() override = default; + + NodeDef* node() const override; + + // Checks if a fanin exists for the node. + bool HasFanin(const MutableFanoutView& fanin) const override; + + // Checks if a fanout exists for the node. + bool HasFanout(const MutableFaninView& fanout) const override; + + private: + inline const MutableFanoutView& GetMissingFanin() const override; + + inline const std::vector& GetMissingFanout() const override; + + absl::flat_hash_map fanins_count_; + absl::flat_hash_map controlling_fanins_index_; + // Index of associated MutableNodeViewDiff in Mutation::updated_nodes_. + // If this is -1, there exists no MutableNodeViewDiff for this node. + int update_index_ = internal::kMissingIndex; + + friend class MutableFaninView; + friend class MutableFanoutView; + friend class MutableGraphView; + friend class Mutation; +}; + +class MutationNewNode { + private: + explicit MutationNewNode(Mutation* mutation, int mutation_counter, int index) + : mutation_(mutation), + mutation_counter_(mutation_counter), + index_(index) {} + + const Mutation* mutation_; + const int mutation_counter_; + const int index_; + + friend class Mutation; +}; + +// Mutation is a helper class that allows rewrites of MutableGraphView. This +// should not be initialized or be used directly. +// Note, if a node is renamed to another node, or a new node is created with the +// same name as an existing node, the node with the same name originally in the +// graph will be overwritten. +class Mutation { + public: + // Create a new node to be added to the graph. If the node's fanins are not + // well formed (self loops, control dependencies between regular fanins), the + // `status` will be set. + MutationNewNode AddNode(NodeDef&& node, Status* status); + + // Remove an existing node in the graph. + void RemoveNode(MutableNodeView* node); + + // Update the name of an existing node. + void UpdateNodeName(MutableNodeView* node, absl::string_view name); + + // Update the name of a new node. + void UpdateNodeName(const MutationNewNode& node, absl::string_view name); + + // Update the op of an existing node. + void UpdateNodeOp(MutableNodeView* node, absl::string_view op); + + // Update the op of a new node. + void UpdateNodeOp(const MutationNewNode& node, absl::string_view op); + + // Update the device of an existing node. + void UpdateNodeDevice(MutableNodeView* node, absl::string_view device); + + // Update the device of a new node. + void UpdateNodeDevice(const MutationNewNode& node, absl::string_view device); + + // Add or replace regular fanin `fanin` at `index` for an existing node. + void AddOrUpdateRegularFanin(MutableNodeView* node, int index, + const TensorId& fanin); + + // Add or replace regular fanin `fanin` at `index` for a new node. + void AddOrUpdateRegularFanin(const MutationNewNode& node, int index, + const TensorId& fanin); + + // Remove regular fanin at `index` for an existing node. + void RemoveRegularFanin(MutableNodeView* node, int index); + + // Remove regular fanin at `index` for a new node. + void RemoveRegularFanin(const MutationNewNode& node, int index); + + // Add controlling fanin `fanin_node_name` for an existing node. + void AddControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name); + + // Add controlling fanin `fanin_node_name` for a new node. + void AddControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name); + + // Remove controlling fanin `fanin_node_name` for an existing node. + void RemoveControllingFanin(MutableNodeView* node, + absl::string_view fanin_node_name); + + // Remove controlling fanin `fanin_node_name` for a new node. + void RemoveControllingFanin(const MutationNewNode& node, + absl::string_view fanin_node_name); + + // Add or replace attribute `attr_name` with `attr_value` for an existing + // node. + void AddOrUpdateNodeAttr(MutableNodeView* node, absl::string_view attr_name, + const AttrValue& attr_value); + + // Add or replace attribute `attr_name` with `attr_value` for a new node. + void AddOrUpdateNodeAttr(const MutationNewNode& node, + absl::string_view attr_name, + const AttrValue& attr_value); + + // Remove attribute `attr_name` for an existing node. + void RemoveNodeAttr(MutableNodeView* node, absl::string_view attr_name); + + // Remove attribute `attr_name` for a new node. + void RemoveNodeAttr(const MutationNewNode& node, absl::string_view attr_name); + + // Reset and clear mutation. + void Reset(); + + // Applies the Mutation to the graph. If the mutation is valid, the graph will + // be modified. Otherwise an error status will be returned and the graph will + // not be modified. + Status Apply(); + + private: + explicit Mutation(MutableGraphView* graph_view); + + void ResetInternal(); + + using MutableNodeViewDiff = internal::NodeViewDiff; + void AddMutation(MutableNodeView* node, + std::function mutate_fn); + + MutableGraphView* graph_view_ = nullptr; + int mutation_counter_ = 0; + std::vector updated_nodes_; + + using MutationNewNodeHolder = internal::NewNode; + std::vector new_nodes_; + + friend class MutableGraphView; +}; + +// Mutable GraphView that holds a mutable GraphDef. This allows for lookups and +// traversals of the graph. Control dependencies will be dedupped among other +// control dependencies on initialization. Mutations should be handled using +// this API instead of directly on the GraphDef/NodeDef. +// Note, after a mutation, pointers of MutableNodeView's from MutableGraphView +// may be invalidated. +class MutableGraphView + : public internal::GraphViewInternal { + public: + explicit MutableGraphView(GraphDef* graph, Status* status); + ~MutableGraphView() override = default; + + // Returns a Mutation (builder) that can be used to modify MutableGraphView. + Mutation* GetMutationBuilder(); + + private: + bool AddUniqueNodeInternal(NodeDef* node); + + Status CheckFaninsInternal(std::vector>* fanins); + + void AddFaninsInternal(std::vector>* fanins); + + // RenamedOrOverwrittenNode holds a index to Mutation::updated_nodes_ for a + // renamed node, alongside a potential overwritten node index in the actual + // graph. If the renamed node is not overwriting any existing nodes, + // `overwritten_node_index_` will be set to `internal::kMissingIndex`. + class RenamedOrOverwrittenNode { + public: + RenamedOrOverwrittenNode(int renamed_update_index, + int overwritten_node_index) + : renamed_update_index_(renamed_update_index), + overwritten_node_index_(overwritten_node_index) {} + + private: + int renamed_update_index_; + int overwritten_node_index_; + + friend class MutableGraphView; + }; + + Status GetNodeNamesAndPartitionUpdatedNodes( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices); + + Status RemovedOrMissingNodeFanoutsWellFormed( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes); + + Status CheckNodeNamesAndFanins( + const absl::flat_hash_map& node_names, + const std::vector& renamed_nodes, + const std::vector& inplace_nodes); + + Status CheckKernelRegisteredForNodes(); + + // Helper class to move fanouts around. + class NodeViewFanouts { + public: + NodeViewFanouts( + std::vector>&& regular_fanouts_by_port, + int num_regular_fanouts, + std::vector controlled_fanouts) + : regular_fanouts_by_port_(std::move(regular_fanouts_by_port)), + num_regular_fanouts_(num_regular_fanouts), + controlled_fanouts_(std::move(controlled_fanouts)) {} + + private: + std::vector> regular_fanouts_by_port_; + int num_regular_fanouts_ = 0; + std::vector controlled_fanouts_; + + friend class MutableGraphView; + }; + + template + void ReplaceNodeFanouts(MutableNodeView* node, T* fanouts); + + void FixRenamedNodes( + std::vector* renamed_nodes, + absl::flat_hash_map* renamed_fanouts, + std::vector* overwritten_name_removed_nodes); + + void AddNewNodes( + absl::flat_hash_map* renamed_fanouts, + std::vector* new_node_indices); + + void FixRenamedFanouts( + const absl::flat_hash_map& renamed_fanouts); + + inline void RemoveRegularFaninFanoutInternal(MutableNodeView* node_view, + int i); + + inline void AddRegularFaninInternal(MutableNodeView* node_view, + const SafeTensorId& fanin_id); + + inline void UpdateRegularFaninInternal(MutableNodeView* node_view, + const int i, + const SafeTensorId& fanin_id); + + inline void RemoveControllingFaninFanoutInternal(MutableNodeView* node_view, + int i); + + inline void RemoveControllingFaninInternal( + MutableNodeView* node_view, const std::set& indices_to_remove); + + inline void AddControllingFaninInternal(MutableNodeView* node_view, + absl::string_view fanin_node_name); + + void ApplyNodeUpdates(); + + void SetNewNodesFanins(const std::vector& new_node_indices); + + inline void RemoveAllFaninFanoutInternal(MutableNodeView* node_view); + + void RemoveNodesInternal( + const std::vector& renamed_nodes, + const std::vector& overwritten_name_removed_nodes); + + inline Status ValidateInternal( + absl::flat_hash_map* node_names, + std::vector* renamed_nodes, + std::vector* inplace_nodes, + std::vector* empty_diff_node_indices); + + Status ApplyMutationInternal(); + + Mutation mutation_; + + friend class MutableNodeView; + friend class Mutation; +}; + +} // namespace utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_H_ diff --git a/tensorflow/core/grappler/utils/graph_view_internal.h b/tensorflow/core/grappler/utils/graph_view_internal.h new file mode 100644 index 00000000000..22e15917fb4 --- /dev/null +++ b/tensorflow/core/grappler/utils/graph_view_internal.h @@ -0,0 +1,898 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/strings/string_view.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { +namespace grappler { +namespace utils { +namespace internal { + +constexpr int kMissingSlot = -2; +constexpr int kMissingIndex = -1; +constexpr int kNodeNamePresent = -1; + +// NodeIndexAndPortIndex is a helper class that represents fanins and fanouts +// of a node. +template +class NodeIndexAndPortIndex { + public: + NodeIndexAndPortIndex() + : graph_view_(nullptr), + node_index_(kMissingIndex), + port_index_(kMissingSlot) {} + NodeIndexAndPortIndex(GraphViewT* graph_view, int node_index, int port_index) + : graph_view_(graph_view), + node_index_(node_index), + port_index_(port_index) {} + + bool operator==(const NodeIndexAndPortIndex& other) const { + return port_index_ == other.port_index_ && + node_index_ == other.node_index_ && graph_view_ == other.graph_view_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeIndexAndPortIndex& n) { + return Hash::combine(std::move(h), n.node_index_, n.port_index_); + } + + // Returns NodeView from `graph_view_` at `node_index_`. + NodeViewT* node_view() const { + if (graph_view_ == nullptr) { + return nullptr; + } + return graph_view_->GetNode(node_index_); + } + + // Returns node index in graph. + int node_index() const { return node_index_; } + + // Returns input/output port index. + int index() const { return port_index_; } + + protected: + GraphViewT* graph_view_; + int node_index_; + int port_index_; +}; + +// NodeDefAndPortIndex is a helper class that represents fanins hashed with +// pointer stability using the fanin's NodeDef. +class NodeDefAndPortIndex { + public: + NodeDefAndPortIndex(const NodeDef* node_def, int port_index) + : node_def_(node_def), port_index_(port_index) {} + + bool operator==(const NodeDefAndPortIndex& other) const { + return node_def_ == other.node_def_ && port_index_ == other.port_index_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeDefAndPortIndex& n) { + return Hash::combine(std::move(h), n.node_def_, n.port_index_); + } + + private: + const NodeDef* node_def_; + int port_index_; +}; + +// NodeViewInternal is a helper class to simplify graph traversal. It creates +// a view of a node and associated fanins and fanouts from the NodeDef +// protocol buffer. +// +// There are two public classes implementing NodeViewInternal: +// +// - NodeView: constructed from `const NodeDef` and doesn't allow mutating the +// underlying node. +// - MutableNodeView: constructed from `NodeDef` and allows mutating the +// underlying node. +// +// --------------------------- !!! WARNING !!! --------------------------------- +// Modifying the node outside of implementations of NodeViewInternal +// (i.e. modifying inputs of the NodeDef directly) may leave the NodeView +// in an inconsistent/invalid state. +// ----------------------------------------------------------------------------- +// +template +class NodeViewInternal { + private: + using NodeDefT = + typename std::conditional::type; + + public: + explicit NodeViewInternal(GraphViewT* graph_view, int node_index) + : graph_view_(graph_view), + node_index_(node_index), + attrs_(AttrSlice(graph_view->graph()->node(node_index))) {} + virtual ~NodeViewInternal() {} + + bool operator==(const NodeViewInternal& other) const { + return node_index_ == other.node_index_ && graph_view_ == other.graph_view_; + } + + template + friend Hash AbslHashValue(Hash h, const NodeViewInternal& n) { + return Hash::combine(std::move(h), n.node_index_); + } + + // Returns NodeDef of view. + virtual NodeDefT* node() const = 0; + + // Returns index of node in GraphDef/GraphView. + int node_index() const { return node_index_; } + + // Returns the name of the node. + const string& GetName() const { return node()->name(); } + + // Returns the op of the node. + const string& GetOp() const { return node()->op(); } + + // Returns the device set for the node. + const string& GetDevice() const { return node()->device(); } + + // Returns all regular fanins, based on ordering in the node. + const std::vector& GetRegularFanins() const { + return regular_fanins_; + } + + // Returns a regular fanin based on input index. If no such fanin exist, a + // missing fanin is returned, with no NodeView set and an index of -2. + const FanoutViewT& GetRegularFanin(int i) const { + if (i < 0 || i >= regular_fanins_.size()) { + return GetMissingFanin(); + } + return regular_fanins_[i]; + } + + // Returns all controlling fanins, based on ordering in the node. + const std::vector& GetControllingFanins() const { + return controlling_fanins_; + } + + // Returns all regular fanouts. + const std::vector>& GetRegularFanouts() const { + return regular_fanouts_by_port_; + } + + // Returns a regular fanout(s) based on output index. If no such output index + // exists, no fanouts will be returned. + const std::vector& GetRegularFanout(int i) const { + if (i < 0 || i >= regular_fanouts_by_port_.size()) { + return GetMissingFanout(); + } + return regular_fanouts_by_port_[i]; + } + + // Returns all controlled fanouts. + const std::vector& GetControlledFanouts() const { + return controlled_fanouts_; + } + + // Returns the number of regular fanins. + int NumRegularFanins() const { return regular_fanins_.size(); } + + // Returns the number of controlling fanins. + int NumControllingFanins() const { return controlling_fanins_.size(); } + + // Returns the number of regular fanouts. + int NumRegularFanouts() const { return num_regular_fanouts_; } + + // Returns the number of controlled fanouts. + int NumControlledFanouts() const { return controlled_fanouts_.size(); } + + // Checks if a fanin exists for the node. + virtual bool HasFanin(const FanoutViewT& fanin) const = 0; + + // Checks if a fanout exists for the node. + virtual bool HasFanout(const FaninViewT& fanout) const = 0; + + // Returns an attribute of the node by key. If no attribute for such key + // exists, a `nullptr` is returned. + const AttrValue* GetAttr(absl::string_view attr_name) const { + return attrs_.Find(attr_name); + } + + // Returns all attributes of the node. + const AttrSlice& GetAttrs() const { return attrs_; } + + // Returns the number of attributes in the node. + int NumAttrs() const { return attrs_.size(); } + + // Checks if an attribute exist in the node. + bool HasAttr(absl::string_view attr_name) const { + return attrs_.Find(attr_name) != nullptr; + } + + protected: + virtual inline const FanoutViewT& GetMissingFanin() const = 0; + virtual inline const std::vector& GetMissingFanout() const = 0; + + std::vector regular_fanins_; + std::vector controlling_fanins_; + std::vector> regular_fanouts_by_port_; + int num_regular_fanouts_ = 0; + std::vector controlled_fanouts_; + + GraphViewT* graph_view_; + int node_index_; + AttrSlice attrs_; +}; + +// GraphViewInternal is a helper class to simplify graph traversal. It creates +// a view of the nodes and associated fanins and fanouts from the GraphDef +// protocol buffer. +// +// There are two public classes implementing GraphViewInternal: +// +// - GraphView: constructed from `const GraphDef` and doesn't allow mutating +// the underlying graph and its nodes. +// - MutableGraphView: constructed from `GraphDef` and allows mutating the +// underlying graph and its nodes. +// +// --------------------------- !!! WARNING !!! --------------------------------- +// Modifying the graph outside of implementations of GraphViewInternal +// (i.e. removing nodes from the GraphDef directly) may lead to +// segfaults! Guaranteed by absl::string_view! +// ----------------------------------------------------------------------------- +// +template +class GraphViewInternal { + private: + using GraphDefT = + typename std::conditional::type; + + public: + explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} + virtual ~GraphViewInternal() {} + + bool operator==(const GraphViewInternal& other) const { + return graph_ == other.graph_; + } + + GraphDefT* graph() const { return graph_; } + + // Finds node by index in the graph. If no such node exists in the graph, a + // `nullptr` is returned. + const NodeViewT* GetNode(int node_index) const { + if (node_index < 0 || node_index >= nodes_.size()) { + return nullptr; + } + return &nodes_[node_index]; + } + + NodeViewT* GetNode(int node_index) { + if (node_index < 0 || node_index >= nodes_.size()) { + return nullptr; + } + return &nodes_[node_index]; + } + + // Finds node by name. If no such node exists in the graph, a `nullptr` is + // returned. + const NodeViewT* GetNode(absl::string_view node_name) const { + auto it = node_index_by_name_.find(node_name); + if (it == node_index_by_name_.end()) { + return nullptr; + } + return &nodes_[it->second]; + } + + NodeViewT* GetNode(absl::string_view node_name) { + auto it = node_index_by_name_.find(node_name); + if (it == node_index_by_name_.end()) { + return nullptr; + } + return &nodes_[it->second]; + } + + // Returns all nodes (as NodeView) in the graph. + const std::vector& GetNodes() const { return nodes_; } + + // Checks if a node by name exists in the graph. + bool HasNode(absl::string_view node_name) const { + return node_index_by_name_.contains(node_name); + } + + // Returns the number of nodes in the graph. + int NumNodes() const { return nodes_.size(); } + + protected: + // Reset allocated node vector and node map in case of failure. + void Reset() { + std::vector().swap(nodes_); + absl::flat_hash_map().swap(node_index_by_name_); + } + + // nodes_[i] is a view of graph_.{mutable_}node(i). + std::vector nodes_; + absl::flat_hash_map node_index_by_name_; + GraphDefT* graph_; + const FanoutViewT missing_fanin_; + const std::vector missing_fanout_; +}; + +inline SafeTensorId EmptyTensorId() { + return SafeTensorId("", internal::kMissingSlot); +} + +inline bool IsEmptyTensorId(const TensorId tensor_id) { + return tensor_id.node().empty() && + tensor_id.index() == internal::kMissingSlot; +} + +// NodeViewDiff is a helper struct holding changes to be made to an existing +// node in GraphViewT. This should not be initialized or be used directly. +template +struct NodeViewDiff { + explicit NodeViewDiff(GraphViewT* graph_view, int node_index) + : graph_view(graph_view), node_index(node_index) {} + + GraphViewT* graph_view; + int node_index; + bool removed = false; + string name; + bool update_name = false; + string op; + bool update_op = false; + string device; + bool update_device = false; + // Fanins to append after existing regular fanins. + std::vector regular_inputs_to_add; + // Number of fanins to be appended. This is used for a quick comparison with + // `regular_inputs_to_add` for if there will be any missing inputs in the + // updated node. + int num_regular_inputs_to_add = 0; + // Fanins to update inplace. + std::map regular_inputs_to_update; + // Fanins from end of regular fanins to remove. This keeps track of existing + // regular fanins in the original node to remove. + std::vector regular_inputs_to_remove; + // Number of fanins marked for removal. This is used for a quick comparison + // with `regular_inputs_to_remove` for if there will be any missing inputs + // in the updated node. + int num_regular_inputs_to_remove = 0; + absl::flat_hash_set controlling_inputs_to_add; + std::set controlling_inputs_to_remove; + absl::flat_hash_map attrs_to_add; + absl::flat_hash_set attrs_to_remove; + AttrValueMap processed_attrs; +}; + +// Sets node for removal via diff. +template +inline void SetRemoved(NodeViewDiff* diff, bool removed) { + diff->removed = removed; +} + +// Updates node name. If `name` is the same as the name in the original node, +// the field will be cleared in the diff. +template +inline void UpdateName(NodeViewDiff* diff, absl::string_view name) { + if (diff->graph_view->GetNode(diff->node_index)->GetName() == name) { + diff->name.clear(); + diff->update_name = false; + } else { + diff->name = string(name); + diff->update_name = true; + } +} + +// Updates node op. If `op` is the same as the op in the original node, the +// field will be cleared in the diff. +template +inline void UpdateOp(NodeViewDiff* diff, absl::string_view op) { + if (diff->graph_view->GetNode(diff->node_index)->GetOp() == op) { + diff->op.clear(); + diff->update_op = false; + } else { + diff->op = string(op); + diff->update_op = true; + } +} + +// Updates node device. If `device` is the same as the device in the original +// node, the field will be cleared in the diff. +template +inline void UpdateDevice(NodeViewDiff* diff, + absl::string_view device) { + if (diff->graph_view->GetNode(diff->node_index)->GetDevice() == device) { + diff->device.clear(); + diff->update_device = false; + } else { + diff->device = string(device); + diff->update_device = true; + } +} + +// Adds or updates value in vector `v` at index `i`. This will also resize the +// vector if index `i` is out of bounds, padding the vector with +// `default_value`. Returns true if a new value was appended or if an update +// occurred where an existing value was changed from `default_value`. +template +inline bool AddOrUpdateAtIndex(std::vector* v, int i, const U& value, + const T& default_value) { + if (i > v->size()) { + // Resize to include `value`, filling the newly introduced gap with + // `default_value` for later checks of validity (gaps in vector). + v->reserve(i + 1); + v->resize(i, default_value); + v->push_back({value}); + } else if (i == v->size()) { + // Vector is large enough, simply append `value` to the end. + v->push_back({value}); + } else { + // Update existing value. + bool updated = (*v)[i] == default_value; + (*v)[i] = {value}; + return updated; + } + return true; +} + +// Checks if a node with name `node_name` will exist in the final mutated graph. +template +inline bool CheckNodeNameExists( + absl::string_view node_name, + const absl::flat_hash_map& updated_node_names, + const GraphViewT* graph_view) { + auto it = updated_node_names.find(node_name); + if (it != updated_node_names.end()) { + return it->second == kNodeNamePresent; + } + return graph_view->HasNode(node_name); +} + +// Adds or updates regular fanin at `index` of regular fanins. If `index` is +// less than the number of regular fanins in the original node, the fanin at +// `index` in the original node will be updated with `fanin` if the fanin +// differs. If `index` is greater than or equal to the number of regular fanins, +// `fanin` will be added beyond the end of regular fanins at `index`. +template +inline void AddOrUpdateRegularFanin(NodeViewDiff* diff, int index, + const TensorId& fanin) { + if (index < 0) { + // Not a valid index for regular fanins. + return; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const int num_regular_fanins = node_view->NumRegularFanins(); + if (index < num_regular_fanins) { // Updating existing fanins. + // Calculate (relative) index from end of regular fanins, from absolute + // index from beginning of regular fanins. + const int relative_removal_index = num_regular_fanins - index - 1; + // Check if at relative index fanin was already marked for removal. + if (relative_removal_index < diff->regular_inputs_to_remove.size() && + diff->regular_inputs_to_remove[relative_removal_index]) { + // Unmark fanin for removal. + diff->regular_inputs_to_remove[relative_removal_index] = false; + --diff->num_regular_inputs_to_remove; + } + const auto& existing_fanin = node_view->GetRegularFanin(index); + if (existing_fanin.index() != fanin.index() || + existing_fanin.node_view()->GetName() != fanin.node()) { + // Update fanin if it is different from original fanin in node. + gtl::InsertOrUpdate(&diff->regular_inputs_to_update, index, + SafeTensorId(fanin)); + } + } else { + // Add fanin beyond current fanin range. + const int relative_add_index = index - num_regular_fanins; + if (AddOrUpdateAtIndex(&diff->regular_inputs_to_add, relative_add_index, + fanin, EmptyTensorId())) { + // New fanin was added. + ++diff->num_regular_inputs_to_add; + } + } +} + +// Remove regular fanin at `index` of regular fanins. This can remove existing +// fanins and updated/added fanins via AddOrUpdateRegularFanins. +template +inline void RemoveRegularFanin(NodeViewDiff* diff, int index) { + if (index < 0) { + // Not a valid index for regular fanins. + return; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const int num_regular_fanins = node_view->NumRegularFanins(); + if (index < num_regular_fanins) { // Removing existing fanins. + // Remove updated fanin if it exists. + diff->regular_inputs_to_update.erase(index); + // Calculate (relative) index from end of regular fanins, from absolute + // index from beginning of regular fanins. + const int relative_removal_index = num_regular_fanins - index - 1; + if (AddOrUpdateAtIndex(&diff->regular_inputs_to_remove, + relative_removal_index, + /*value=*/true, /*default_value=*/false)) { + ++diff->num_regular_inputs_to_remove; + } + } else { + // Relative index from end of regular fanins. + const int relative_add_index = index - num_regular_fanins; + if (relative_add_index >= diff->regular_inputs_to_add.size() || + IsEmptyTensorId(diff->regular_inputs_to_add[relative_add_index])) { + // At relative index, appended regular fanin was already marked for + // removal. + return; + } + // Remove added fanin. + diff->regular_inputs_to_add[relative_add_index] = EmptyTensorId(); + --diff->num_regular_inputs_to_add; + } +} + +// Adds controlling fanin. If the controlling fanin already exists in the +// original node, it will be dedupped. If the controlling fanin is marked for +// removal, this will reverse it. +template +inline void AddControllingFanin(NodeViewDiff* diff, + int control_index, + absl::string_view fanin_node_name) { + if (control_index == kMissingIndex) { + diff->controlling_inputs_to_add.emplace(fanin_node_name); + } else { + diff->controlling_inputs_to_remove.erase(control_index); + } +} + +// Remove controlling fanin. If the controlling fanin does not exist in the +// original node and diff, nothing will happen. If the controlling fanin exists +// in the diff, it will be removed. Otherwise the controlling fanin will be +// marked for removal from the original node. +template +inline void RemoveControllingFanin(NodeViewDiff* diff, + int control_index, + absl::string_view fanin_node_name) { + if (control_index == kMissingIndex) { + diff->controlling_inputs_to_add.erase(fanin_node_name); + } else { + diff->controlling_inputs_to_remove.emplace(control_index); + } +} + +// Adds or updates an attribute by name. If an attribute exist in the original +// node or diff (including those marked for removal), this will overwrite it. +template +inline void AddOrUpdateAttribute(NodeViewDiff* diff, + absl::string_view attr_name, + const AttrValue& attr_value) { + diff->attrs_to_remove.erase(attr_name); + gtl::InsertOrUpdate(&diff->attrs_to_add, string(attr_name), attr_value); +} + +// Removes an attribute by name. If an attribute exist in the original node or +// diff, this will remove it. +template +inline void RemoveAttribute(NodeViewDiff* diff, + absl::string_view attr_name) { + diff->attrs_to_add.erase(attr_name); + auto* node_view = diff->graph_view->GetNode(diff->node_index); + if (node_view->HasAttr(attr_name)) { + diff->attrs_to_remove.emplace(attr_name); + } +} + +// Removes trailing values in vector `v` for values equal to `value`. +template +inline void ResizeByTrimmingEndForValue(std::vector* v, const T& value) { + int curr_index = v->size(); + const int last_index = v->size() - 1; + for (int i = last_index; i >= 0; --i) { + if ((*v)[i] == value) { + curr_index = i; + } else { + break; + } + } + if (curr_index <= last_index) { + v->resize(curr_index); + } +} + +// Checks if any changes are set in the diff. +template +inline bool IsEmpty(NodeViewDiff* diff) { + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false); + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId()); + return !diff->removed && !diff->update_name && !diff->update_op && + !diff->update_device && diff->regular_inputs_to_add.empty() && + diff->regular_inputs_to_update.empty() && + diff->regular_inputs_to_remove.empty() && + diff->controlling_inputs_to_add.empty() && + diff->controlling_inputs_to_remove.empty() && + diff->attrs_to_add.empty() && diff->attrs_to_remove.empty(); +} + +// Resets and clears existing diff. +template +inline void Reset(NodeViewDiff* diff) { + diff->removed = false; + diff->name.clear(); + diff->update_name = false; + diff->op.clear(); + diff->update_op = false; + diff->device.clear(); + diff->update_device = false; + std::vector().swap(diff->regular_inputs_to_add); + diff->num_regular_inputs_to_add = false; + std::map().swap(diff->regular_inputs_to_update); + std::vector().swap(diff->regular_inputs_to_remove); + diff->num_regular_inputs_to_remove = 0; + absl::flat_hash_set().swap(diff->controlling_inputs_to_add); + std::set().swap(diff->controlling_inputs_to_remove); + absl::flat_hash_map().swap(diff->attrs_to_add); + absl::flat_hash_set().swap(diff->attrs_to_remove); +} + +// Checks if changes to node will result in a valid node. +template +inline bool IsWellFormed( + NodeViewDiff* diff, + const absl::flat_hash_map& updated_node_names) { + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_remove, false); + ResizeByTrimmingEndForValue(&diff->regular_inputs_to_add, EmptyTensorId()); + if (diff->regular_inputs_to_add.size() != diff->num_regular_inputs_to_add) { + // Missing regular fanins in between appended fanins. + return false; + } else if (diff->num_regular_inputs_to_add > 0 && + !diff->regular_inputs_to_remove.empty()) { + // Appending new fanins while removing existing fanins, resulting in missing + // regular fanins in between. + return false; + } else if (diff->regular_inputs_to_remove.size() != + diff->num_regular_inputs_to_remove) { + // Regular fanins exist in between removed fanins. + return false; + } + auto* node_view = diff->graph_view->GetNode(diff->node_index); + const string& node_name = + diff->update_name ? diff->name : node_view->GetName(); + auto invalid_node_name = [diff, updated_node_names, + node_name](absl::string_view fanin_node_name) { + return fanin_node_name == node_name || + !CheckNodeNameExists(fanin_node_name, updated_node_names, + diff->graph_view); + }; + + // Check if nodes of all updated and new fanins exist (from name) and if such + // fanins do not introduce self loops. Note, this will not check for if + // unmodified fanins exist. + if (diff->update_name) { + // If name of node was changed in node, check all fanins. Updated fanins are + // checked for existence and self loops. Unmodified fanins are checked for + // self loops. + // `regular_inputs_to_update`, `controlling_inputs_to_remove` are sorted, + // so iterators from these maps/sets can be incremented alongside iteration + // and be used for comparisons. + const int last_index = + node_view->NumRegularFanins() - diff->num_regular_inputs_to_remove - 1; + auto regular_to_update_it = diff->regular_inputs_to_update.begin(); + for (int i = 0; i <= last_index; ++i) { + if (regular_to_update_it != diff->regular_inputs_to_update.end() && + regular_to_update_it->first < i) { + ++regular_to_update_it; + } + if (regular_to_update_it != diff->regular_inputs_to_update.end() && + regular_to_update_it->first == i) { + if (invalid_node_name(regular_to_update_it->second.node())) { + return false; + } + } else { + const string& regular_name = + node_view->GetRegularFanin(i).node_view()->GetName(); + if (regular_name == node_name) { + return false; + } + } + } + + auto& controls = node_view->GetControllingFanins(); + const int num_controls = controls.size(); + auto control_to_remove_it = diff->controlling_inputs_to_remove.begin(); + for (int i = 0; i < num_controls; ++i) { + if (control_to_remove_it != diff->controlling_inputs_to_remove.end() && + *control_to_remove_it < i) { + ++control_to_remove_it; + } + if (control_to_remove_it != diff->controlling_inputs_to_remove.end() && + *control_to_remove_it == i) { + // Control dependency marked for removal, can be ignored. + continue; + } else if (controls[i].node_view()->GetName() == node_name) { + return false; + } + } + } else { + // Name of node was not changed, check only updated fanins under the + // assumption prior fanins were valid. + for (const auto& updated : diff->regular_inputs_to_update) { + const string& fanin_name = updated.second.node(); + if (invalid_node_name(fanin_name)) { + return false; + } + } + } + // Check appended regular fanins. + for (const auto& regular : diff->regular_inputs_to_add) { + if (invalid_node_name(regular.node())) { + return false; + } + } + // Check new controlling fanins. + for (const auto& control : diff->controlling_inputs_to_add) { + if (invalid_node_name(control)) { + return false; + } + } + + return true; +} + +// NewNode is a helper struct holding a new node to be added to a GraphViewT. +// This should not be initialized or be used directly. +template +struct NewNode { + explicit NewNode(GraphViewT* graph_view, NodeDef&& node) + : graph_view(graph_view), node(std::move(node)) {} + + GraphViewT* graph_view; + NodeDef node; + std::vector regular_fanins; + int num_regular_fanins = 0; + absl::flat_hash_set controlling_fanins; +}; + +// Updates new node name. +template +inline void UpdateName(NewNode* new_node, absl::string_view name) { + if (name.empty()) { + new_node->node.clear_name(); + } else { + new_node->node.set_name(string(name)); + } +} + +// Updates new node op. +template +inline void UpdateOp(NewNode* new_node, absl::string_view op) { + if (op.empty()) { + new_node->node.clear_op(); + } else { + new_node->node.set_op(string(op)); + } +} + +// Updates new node device. +template +inline void UpdateDevice(NewNode* new_node, + absl::string_view device) { + if (device.empty()) { + new_node->node.clear_device(); + } else { + new_node->node.set_device(string(device)); + } +} + +// Adds or updates regular fanin at `index` of regular fanins in the new node. +// If another fanin already exists at `index`, it will be replaced with `fanin`. +template +inline void AddOrUpdateRegularFanin(NewNode* new_node, int index, + const TensorId& fanin) { + if (index < 0) { + // Not a valid index for regular fanins. + return; + } else if (AddOrUpdateAtIndex(&new_node->regular_fanins, index, fanin, + EmptyTensorId())) { + ++new_node->num_regular_fanins; + } +} + +// Remove regular fanin at `index` of regular fanins in the new node. This can +// remove existing fanins and updated/added fanins via AddOrUpdateRegularFanins. +template +inline void RemoveRegularFanin(NewNode* new_node, int index) { + if (index < 0 || index >= new_node->regular_fanins.size() || + IsEmptyTensorId(new_node->regular_fanins[index])) { + return; + } + new_node->regular_fanins[index] = EmptyTensorId(); + --new_node->num_regular_fanins; +} + +// Adds controlling fanin to new node. +template +inline void AddControllingFanin(NewNode* new_node, + absl::string_view fanin_node_name) { + new_node->controlling_fanins.emplace(fanin_node_name); +} + +// Removes controlling fanin to new node. +template +inline void RemoveControllingFanin(NewNode* new_node, + absl::string_view fanin_node_name) { + new_node->controlling_fanins.erase(fanin_node_name); +} + +// Adds or updates an attribute by name to a new node. +template +inline void AddOrUpdateAttribute(NewNode* new_node, + absl::string_view attr_name, + const AttrValue& attr_value) { + gtl::InsertOrUpdate(new_node->node.mutable_attr(), string(attr_name), + attr_value); +} + +// Removes an attribute by name to a new node. +template +inline void RemoveAttribute(NewNode* new_node, + absl::string_view attr_name) { + new_node->node.mutable_attr()->erase(string(attr_name)); +} + +// Checks if current state of new node is a valid node. +template +inline bool IsWellFormed( + NewNode* new_node, + const absl::flat_hash_map& updated_node_names) { + ResizeByTrimmingEndForValue(&new_node->regular_fanins, EmptyTensorId()); + if (new_node->regular_fanins.size() != new_node->num_regular_fanins) { + return false; + } + + const string& node_name = new_node->node.name(); + auto invalid_node_name = [new_node, updated_node_names, + node_name](absl::string_view fanin_node_name) { + return fanin_node_name == node_name || + !CheckNodeNameExists(fanin_node_name, updated_node_names, + new_node->graph_view); + }; + // Check if nodes of all fanins exist (from name) and if fanins do not + // introduce self loops. + for (const auto& regular : new_node->regular_fanins) { + if (invalid_node_name(regular.node())) { + return false; + } + } + for (const auto& control : new_node->controlling_fanins) { + if (invalid_node_name(control)) { + return false; + } + } + + return true; +} + +} // namespace internal +} // namespace utils +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPH_VIEW_INTERNAL_H_ diff --git a/tensorflow/core/grappler/utils/graph_view_internal_test.cc b/tensorflow/core/grappler/utils/graph_view_internal_test.cc new file mode 100644 index 00000000000..cb959aea16b --- /dev/null +++ b/tensorflow/core/grappler/utils/graph_view_internal_test.cc @@ -0,0 +1,1112 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/graph_view_internal.h" + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/grappler/utils/graph_view.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace utils { +namespace internal { +namespace { + +using ::tensorflow::test::function::GDef; +using ::tensorflow::test::function::NDef; + +constexpr char kNodeOp[] = "NotImportant"; + +GraphDef SimpleTestGraphForMutation() { + return GDef( + {NDef("a", kNodeOp, {}), NDef("b", kNodeOp, {}), NDef("c", kNodeOp, {}), + NDef("d", kNodeOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, "device_d")}, + /*funcs=*/{}); +} + +absl::flat_hash_map GetUpdatedNodeNames( + const MutableGraphView* graph_view) { + absl::flat_hash_map updated_node_names; + updated_node_names.reserve(graph_view->NumNodes()); + for (const auto& node_view : graph_view->GetNodes()) { + updated_node_names.emplace(node_view.GetName(), -1); + } + return updated_node_names; +} + +using MutableNodeViewDiff = NodeViewDiff; + +TEST(MutableNodeViewDiffTest, SetRemoved) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + SetRemoved(&diff, true); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + SetRemoved(&diff, false); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, UpdateName) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateName(&diff, "d"); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, UpdateOp) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateOp(&diff, "RandomOp"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateOp(&diff, kNodeOp); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, UpdateDevice) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateDevice(&diff, "random_device"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + UpdateDevice(&diff, "device_d"); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, AddOrUpdateRegularFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Bad index. + AddOrUpdateRegularFanin(&diff, -1, {"a", 0}); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Set fanin to same existing fanin. + AddOrUpdateRegularFanin(&diff, 0, {"a", 2}); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Update existing fanin. + AddOrUpdateRegularFanin(&diff, 0, {"a", 3}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 4 resulting in missing fanin at index 3. + AddOrUpdateRegularFanin(&diff, 4, {"b", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 3. + AddOrUpdateRegularFanin(&diff, 3, {"c", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 5. + AddOrUpdateRegularFanin(&diff, 5, {"c", 5}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, AddOrUpdateRegularFaninBetweenRemovedFanins) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveRegularFanin(&diff, 0); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + RemoveRegularFanin(&diff, 2); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 1, {"c", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 0, {"c", 0}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveRegularFanin(&diff, 0); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 2, {"c", 2}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RemoveRegularFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Bad index. + RemoveRegularFanin(&diff, -1); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + RemoveRegularFanin(&diff, 3); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 4 resulting in missing fanin at index 3. + AddOrUpdateRegularFanin(&diff, 4, {"b", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 4. + RemoveRegularFanin(&diff, 4); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 4 resulting in missing fanin at index 3. + AddOrUpdateRegularFanin(&diff, 4, {"b", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Add new fanin at index 3. + AddOrUpdateRegularFanin(&diff, 3, {"c", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 3. + RemoveRegularFanin(&diff, 3); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 4. + RemoveRegularFanin(&diff, 4); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Add new fanin at index 5 resulting in missing fanin at indices 3 and 4. + AddOrUpdateRegularFanin(&diff, 5, {"b", 6}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Add new fanin at index 3 resulting in missing fanin at index 4. + AddOrUpdateRegularFanin(&diff, 3, {"c", 4}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove missing fanin at index 4. + RemoveRegularFanin(&diff, 4); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 3. + RemoveRegularFanin(&diff, 3); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 5. + RemoveRegularFanin(&diff, 5); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Update existing fanin. + AddOrUpdateRegularFanin(&diff, 1, {"a", 3}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + // Remove fanin at index 1. + RemoveRegularFanin(&diff, 1); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Add original fanin at index 1. + AddOrUpdateRegularFanin(&diff, 1, {"b", 3}); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RemoveRegularFaninResize) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"c", 5}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + AddOrUpdateRegularFanin(&diff, 4, {"c", 6}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + AddOrUpdateRegularFanin(&diff, 5, {"c", 7}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + // Remove fanin in middle of appended regular fanins. + RemoveRegularFanin(&diff, 4); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); + // Remove last fanin in appended regular fanins. + RemoveRegularFanin(&diff, 5); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, AddControllingFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, 0, "c"); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "a"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RemoveControllingFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "a"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveControllingFanin(&diff, 0, "c"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveControllingFanin(&diff, kMissingIndex, "a"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, 0, "c"); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, AddOrUpdateAttribute) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AttrValue attr_1; + attr_1.set_b(true); + AddOrUpdateAttribute(&diff, "attr_1", attr_1); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AttrValue attr_3; + attr_3.set_i(4); + AddOrUpdateAttribute(&diff, "attr_1", attr_3); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RemoveAttribute) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AttrValue attr_1; + attr_1.set_b(true); + AddOrUpdateAttribute(&diff, "attr_1", attr_1); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveAttribute(&diff, "attr_1"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveAttribute(&diff, "attr_3"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, Reset) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveRegularFanin(&diff, 2); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "a"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AttrValue attr_1; + attr_1.set_b(true); + AddOrUpdateAttribute(&diff, "attr_1", attr_1); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + Reset(&diff); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedWithRemovedAndAppendedFanins) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + RemoveRegularFanin(&diff, 2); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"a", 8}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedSelfLoopRegularUpdate) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 0, {"d", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedSelfLoopRegularNew) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"d", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedSelfLoopControl) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "d"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedMissingFaninRegularUpdate) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 0, {"e", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedMissingFaninRegularNew) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"e", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedMissingControl) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedSelfLoopRegularUpdate) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 0, {"e", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedSelfLoopRegularNew) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"e", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedSelfLoopControl) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedMissingFaninRegularUpdate) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 0, {"f", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedMissingFaninRegularNew) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddOrUpdateRegularFanin(&diff, 3, {"f", 1}); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, IsWellFormedRenamedMissingFaninControl) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + AddControllingFanin(&diff, kMissingIndex, "f"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RenamedAndRemovedFanins) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + string old_node_name = "d"; + string new_node_name = "e"; + updated_node_names.erase(old_node_name); + updated_node_names.emplace(old_node_name, 3); + updated_node_names.emplace(new_node_name, -1); + + UpdateName(&diff, "e"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + for (int i = 0; i < 3; ++i) { + RemoveRegularFanin(&diff, i); + } + RemoveControllingFanin(&diff, 0, "c"); + RemoveControllingFanin(&diff, 0, "b"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); +} + +TEST(MutableNodeViewDiffTest, RenamedWithSelfLoopControl) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + MutableNodeViewDiff diff(&graph_view, d_node->node_index()); + EXPECT_TRUE(IsEmpty(&diff)); + EXPECT_TRUE(IsWellFormed(&diff, updated_node_names)); + + updated_node_names.erase("d"); + + UpdateName(&diff, "c"); + EXPECT_FALSE(IsEmpty(&diff)); + EXPECT_FALSE(IsWellFormed(&diff, updated_node_names)); +} + +using MutationNewNodeForTest = NewNode; + +TEST(MutationNewNodeTest, UpdateName) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, ""); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, UpdateOp) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateOp(&new_node, "Identity"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateOp(&new_node, ""); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, UpdateDevice) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateDevice(&new_node, "foo_device"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateDevice(&new_node, ""); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, AddOrUpdateRegularFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + // Bad index. + AddOrUpdateRegularFanin(&new_node, -1, {"a", 1}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + // Fanin at index 0 is missing. + AddOrUpdateRegularFanin(&new_node, 1, {"a", 1}); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + AddOrUpdateRegularFanin(&new_node, 0, {"b", 2}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + AddOrUpdateRegularFanin(&new_node, 2, {"c", 3}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + // Update inplace. + AddOrUpdateRegularFanin(&new_node, 1, {"d", 4}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + // Missing fanin. + AddOrUpdateRegularFanin(&new_node, 1, {"e", 5}); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + + // Self loop. + AddOrUpdateRegularFanin(&new_node, 1, {"new", 6}); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + + AddOrUpdateRegularFanin(&new_node, 1, {"d", 4}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, RemoveRegularFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + AddOrUpdateRegularFanin(&new_node, 0, {"a", 1}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + AddOrUpdateRegularFanin(&new_node, 1, {"b", 2}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + AddOrUpdateRegularFanin(&new_node, 2, {"c", 3}); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + RemoveRegularFanin(&new_node, 3); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + RemoveRegularFanin(&new_node, 2); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + RemoveRegularFanin(&new_node, 0); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + RemoveRegularFanin(&new_node, 1); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, AddControllingFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + AddControllingFanin(&new_node, "a"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + // Missing fanin. + AddControllingFanin(&new_node, "e"); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + + // Self loop. + AddControllingFanin(&new_node, "new"); + EXPECT_FALSE(IsWellFormed(&new_node, updated_node_names)); + + RemoveControllingFanin(&new_node, "e"); + RemoveControllingFanin(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, RemoveControllingFanin) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + UpdateName(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + AddControllingFanin(&new_node, "a"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + RemoveControllingFanin(&new_node, "e"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + RemoveControllingFanin(&new_node, "new"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + + RemoveControllingFanin(&new_node, "a"); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, AddOrUpdateAttribute) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + string attr_name = "attr_name"; + AttrValue attr_1; + attr_1.set_i(8); + AddOrUpdateAttribute(&new_node, attr_name, attr_1); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + AttrValue attr_2; + attr_2.set_f(2.0f); + AddOrUpdateAttribute(&new_node, attr_name, attr_2); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +TEST(MutationNewNodeTest, RemoveAttribute) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + auto updated_node_names = GetUpdatedNodeNames(&graph_view); + + MutationNewNodeForTest new_node(&graph_view, {}); + + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + string attr_name = "attr_name"; + AttrValue attr_1; + attr_1.set_i(8); + AddOrUpdateAttribute(&new_node, attr_name, attr_1); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + RemoveAttribute(&new_node, attr_name); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); + RemoveAttribute(&new_node, attr_name); + EXPECT_TRUE(IsWellFormed(&new_node, updated_node_names)); +} + +} // namespace +} // namespace internal +} // namespace utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/graph_view_test.cc b/tensorflow/core/grappler/utils/graph_view_test.cc new file mode 100644 index 00000000000..4dd8ecc0254 --- /dev/null +++ b/tensorflow/core/grappler/utils/graph_view_test.cc @@ -0,0 +1,2545 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/graph_view.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/substitute.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/benchmark_testlib.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { +namespace grappler { +namespace utils { +namespace { + +using ::tensorflow::test::function::GDef; +using ::tensorflow::test::function::NDef; + +constexpr char kNoOp[] = "NoOp"; + +GraphDef SimpleTestGraph() { + return GDef({NDef("a", kNoOp, {"b:2", "d:3", "b:2", "d:3", "^c"}), + NDef("b", kNoOp, {"d:2", "c:5", "^c"}), + NDef("c", kNoOp, {"^d", "^d"}), NDef("d", kNoOp, {})}, + /*funcs=*/{}); +} + +template +class TypedGraphViewTest : public ::testing::Test { + public: + const string type_as_string_ = + std::is_same::value ? "GraphView" : "MutableGraphView"; +}; + +using GraphViewTypes = ::testing::Types; +TYPED_TEST_SUITE(TypedGraphViewTest, GraphViewTypes); + +TYPED_TEST(TypedGraphViewTest, GraphWithDuplicateNodeNames) { + GraphDef graph = + GDef({NDef("a", kNoOp, {}), NDef("a", kNoOp, {})}, /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + absl::Substitute( + "$0::$0 error: graph has multiple nodes with the name 'a'.", + this->type_as_string_)); +} + +TYPED_TEST(TypedGraphViewTest, GraphWithMissingFanins) { + GraphDef graph = GDef({NDef("a", kNoOp, {"b:3"})}, /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + absl::Substitute("$0::$0 error: node 'a' has missing fanin 'b:3'.", + this->type_as_string_)); +} + +TYPED_TEST(TypedGraphViewTest, GraphWithSelfCycles) { + GraphDef graph = GDef({NDef("a", kNoOp, {"a:4"})}, /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ( + s.error_message(), + absl::Substitute("$0::$0 error: node 'a' has self cycle fanin 'a:4'.", + this->type_as_string_)); +} + +TYPED_TEST(TypedGraphViewTest, GraphWithMisorderedFanins) { + GraphDef graph = GDef({NDef("a", kNoOp, {"^b", "b:4"}), NDef("b", kNoOp, {})}, + /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + absl::Substitute("$0::$0 error: node 'a' has regular fanin 'b:4' " + "after controlling fanins.", + this->type_as_string_)); +} + +TYPED_TEST(TypedGraphViewTest, GetNodeWithIndex) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + const int num_nodes = graph_view.NumNodes(); + ASSERT_EQ(graph_view.NumNodes(), graph.node_size()); + for (int i = 0; i < num_nodes; ++i) { + const auto* node = graph_view.GetNode(i); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->node(), graph.mutable_node(i)); + } + + const auto* bad_node = graph_view.GetNode(-1); + ASSERT_EQ(bad_node, nullptr); + bad_node = graph_view.GetNode(num_nodes); + ASSERT_EQ(bad_node, nullptr); +} + +TYPED_TEST(TypedGraphViewTest, GetNodeWithName) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + std::vector node_names = {"a", "b", "c", "d"}; + for (int i = 0; i < node_names.size(); ++i) { + const string& node_name = node_names[i]; + const auto* node = graph_view.GetNode(node_name); + ASSERT_NE(node, nullptr); + EXPECT_EQ(node->node(), graph.mutable_node(i)); + } + + // Missing node. + const auto* bad_node = graph_view.GetNode("e"); + ASSERT_EQ(bad_node, nullptr); +} + +TYPED_TEST(TypedGraphViewTest, GetNodes) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + const auto& nodes = graph_view.GetNodes(); + const int num_nodes = nodes.size(); + EXPECT_EQ(num_nodes, 4); + + ASSERT_EQ(num_nodes, graph.node_size()); + for (int i = 0; i < num_nodes; ++i) { + EXPECT_EQ(nodes[i].node(), graph.mutable_node(i)); + } +} + +TYPED_TEST(TypedGraphViewTest, HasNode) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + for (const string& node_name : {"a", "b", "c", "d"}) { + EXPECT_TRUE(graph_view.HasNode(node_name)); + } + + // Missing node. + EXPECT_FALSE(graph_view.HasNode("e")); +} + +TYPED_TEST(TypedGraphViewTest, NumNodes) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + EXPECT_EQ(graph_view.NumNodes(), 4); +} + +TYPED_TEST(TypedGraphViewTest, NumNodesEmptyGraph) { + GraphDef graph; + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + EXPECT_EQ(graph_view.NumNodes(), 0); +} + +TEST(MutableGraphViewTest, DedupControlDependencies) { + GraphDef graph = GDef( + {NDef("a", kNoOp, {}), NDef("b", kNoOp, {}), NDef("c", kNoOp, {}), + NDef("d", kNoOp, {"a:2", "b:1", "^c", "^c", "^a", "^a", "^b", "^c"})}, + /*funcs=*/{}); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + EXPECT_EQ(graph_view.NumNodes(), 4); + + const auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + const auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + const auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + const auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + EXPECT_EQ(d_node->NumRegularFanins(), 2); + ASSERT_NE(d_node->node(), nullptr); + ASSERT_EQ(d_node->node()->input_size(), 5); + EXPECT_EQ(d_node->node()->input(0), "a:2"); + EXPECT_EQ(d_node->node()->input(1), "b:1"); + EXPECT_EQ(d_node->node()->input(2), "^c"); + EXPECT_EQ(d_node->node()->input(3), "^b"); + EXPECT_EQ(d_node->node()->input(4), "^a"); + ASSERT_EQ(d_node->NumControllingFanins(), 3); + const auto& d_control_fanins = d_node->GetControllingFanins(); + ASSERT_EQ(d_control_fanins.size(), 3); + ASSERT_NE(d_control_fanins[0].node_view(), nullptr); + EXPECT_EQ(d_control_fanins[0].node_view()->GetName(), "c"); + ASSERT_NE(d_control_fanins[1].node_view(), nullptr); + EXPECT_EQ(d_control_fanins[1].node_view()->GetName(), "b"); + ASSERT_NE(d_control_fanins[2].node_view(), nullptr); + EXPECT_EQ(d_control_fanins[2].node_view()->GetName(), "a"); +} + +template +class TypedNodeViewTest : public ::testing::Test {}; +TYPED_TEST_SUITE(TypedNodeViewTest, GraphViewTypes); + +TYPED_TEST(TypedNodeViewTest, GetName) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + for (const NodeDef& node : graph.node()) { + const auto* node_view = graph_view.GetNode(node.name()); + ASSERT_NE(node_view, nullptr); + EXPECT_EQ(node_view->GetName(), node.name()); + EXPECT_EQ(node_view->GetName(), node_view->node()->name()); + } +} + +TYPED_TEST(TypedNodeViewTest, GetOp) { + GraphDef graph = GDef({NDef("a", "op_a", {}), NDef("b", "op_b", {}), + NDef("c", "op_c", {}), NDef("d", "op_d", {})}, + /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + const auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + EXPECT_EQ(a_node->GetOp(), "op_a"); + EXPECT_EQ(a_node->node()->op(), "op_a"); + const auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + EXPECT_EQ(b_node->GetOp(), "op_b"); + EXPECT_EQ(b_node->node()->op(), "op_b"); + const auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + EXPECT_EQ(c_node->GetOp(), "op_c"); + EXPECT_EQ(c_node->node()->op(), "op_c"); + const auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + EXPECT_EQ(d_node->GetOp(), "op_d"); + EXPECT_EQ(d_node->node()->op(), "op_d"); +} + +TYPED_TEST(TypedNodeViewTest, GetDevice) { + GraphDef graph = GDef( + {NDef("a", "", {}, {}, "device_a"), NDef("b", "", {}, {}, "device_b"), + NDef("c", "", {}, {}, "device_c"), NDef("d", "", {}, {})}, + /*funcs=*/{}); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + const auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + EXPECT_EQ(a_node->GetDevice(), "device_a"); + EXPECT_EQ(a_node->node()->device(), "device_a"); + const auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + EXPECT_EQ(b_node->GetDevice(), "device_b"); + EXPECT_EQ(b_node->node()->device(), "device_b"); + const auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + EXPECT_EQ(c_node->GetDevice(), "device_c"); + EXPECT_EQ(c_node->node()->device(), "device_c"); + const auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + EXPECT_EQ(d_node->GetDevice(), ""); + EXPECT_EQ(d_node->node()->device(), ""); +} + +template +class TypedFaninTest : public ::testing::Test {}; +using FaninTypes = + ::testing::Types, + std::pair>; +TYPED_TEST_SUITE(TypedFaninTest, FaninTypes); + +TYPED_TEST(TypedFaninTest, GetRegularFanins) { + using FanoutViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& a_fanins = a_node->GetRegularFanins(); + ASSERT_EQ(a_fanins.size(), 4); + EXPECT_EQ(a_fanins[0], FanoutViewType(&graph_view, b_node->node_index(), 2)); + EXPECT_EQ(a_fanins[1], FanoutViewType(&graph_view, d_node->node_index(), 3)); + EXPECT_EQ(a_fanins[2], FanoutViewType(&graph_view, b_node->node_index(), 2)); + EXPECT_EQ(a_fanins[3], FanoutViewType(&graph_view, d_node->node_index(), 3)); + + const auto& d_fanins = d_node->GetRegularFanins(); + EXPECT_EQ(d_fanins.size(), 0); +} + +TYPED_TEST(TypedFaninTest, GetRegularFanin) { + using FanoutViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& a_fanin_0 = a_node->GetRegularFanin(0); + EXPECT_EQ(a_fanin_0, FanoutViewType(&graph_view, b_node->node_index(), 2)); + const auto& a_fanin_1 = a_node->GetRegularFanin(1); + EXPECT_EQ(a_fanin_1, FanoutViewType(&graph_view, d_node->node_index(), 3)); + const auto& a_fanin_2 = a_node->GetRegularFanin(2); + EXPECT_EQ(a_fanin_2, FanoutViewType(&graph_view, b_node->node_index(), 2)); + const auto& a_fanin_3 = a_node->GetRegularFanin(3); + EXPECT_EQ(a_fanin_3, FanoutViewType(&graph_view, d_node->node_index(), 3)); + + // Out of bounds. + const FanoutViewType missing_fanin; + EXPECT_EQ(missing_fanin, FanoutViewType(nullptr, -1, -2)); + EXPECT_EQ(missing_fanin.node_view(), nullptr); + const auto& a_fanin_4 = a_node->GetRegularFanin(4); + EXPECT_EQ(a_fanin_4, missing_fanin); + const auto& a_fanin_5 = a_node->GetRegularFanin(5); + EXPECT_EQ(a_fanin_5, missing_fanin); + const auto& a_fanin_control = a_node->GetRegularFanin(Graph::kControlSlot); + EXPECT_EQ(a_fanin_control, missing_fanin); + const auto& a_fanin_bad = a_node->GetRegularFanin(-2); + EXPECT_EQ(a_fanin_bad, missing_fanin); +} + +TYPED_TEST(TypedFaninTest, GetControllingFanins) { + using FanoutViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& a_fanins = a_node->GetControllingFanins(); + ASSERT_EQ(a_fanins.size(), 1); + EXPECT_EQ(a_fanins[0], FanoutViewType(&graph_view, c_node->node_index(), + Graph::kControlSlot)); + + const auto& c_fanins = c_node->GetControllingFanins(); + FanoutViewType d_control_fanin(&graph_view, d_node->node_index(), + Graph::kControlSlot); + if (std::is_same::value) { + ASSERT_EQ(c_fanins.size(), 2); + EXPECT_EQ(c_fanins[0], d_control_fanin); + EXPECT_EQ(c_fanins[1], d_control_fanin); + } else { // MutableGraphView will dedup control dependency. + ASSERT_EQ(c_fanins.size(), 1); + EXPECT_EQ(c_fanins[0], d_control_fanin); + } + + const auto& d_fanins = d_node->GetControllingFanins(); + EXPECT_EQ(d_fanins.size(), 0); +} + +template +class TypedFanoutTest : public ::testing::Test {}; +using FanoutTypes = + ::testing::Types, + std::pair>; +TYPED_TEST_SUITE(TypedFanoutTest, FanoutTypes); + +TYPED_TEST(TypedFanoutTest, GetRegularFanouts) { + using FaninViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& d_fanouts = d_node->GetRegularFanouts(); + ASSERT_EQ(d_fanouts.size(), 4); + for (int i = 0; i < d_fanouts.size(); ++i) { + if (i == 2) { + ASSERT_EQ(d_fanouts[i].size(), 1); + EXPECT_EQ(d_fanouts[i][0], + FaninViewType(&graph_view, b_node->node_index(), 0)); + } else if (i == 3) { + ASSERT_EQ(d_fanouts[i].size(), 2); + absl::flat_hash_set fanouts(d_fanouts[i].begin(), + d_fanouts[i].end()); + EXPECT_TRUE(fanouts.contains( + FaninViewType(&graph_view, a_node->node_index(), 1))); + EXPECT_TRUE(fanouts.contains( + FaninViewType(&graph_view, a_node->node_index(), 3))); + } else { + EXPECT_EQ(d_fanouts[i].size(), 0); + } + } + + const auto& a_fanouts = a_node->GetRegularFanouts(); + EXPECT_EQ(a_fanouts.size(), 0); +} + +TYPED_TEST(TypedFanoutTest, GetRegularFanout) { + using FaninViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& d_fanouts_2 = d_node->GetRegularFanout(2); + ASSERT_EQ(d_fanouts_2.size(), 1); + EXPECT_EQ(d_fanouts_2.at(0), + FaninViewType(&graph_view, b_node->node_index(), 0)); + + const auto& d_fanouts_3 = d_node->GetRegularFanout(3); + EXPECT_EQ(d_fanouts_3.size(), 2); + absl::flat_hash_set d_fanouts_3_set(d_fanouts_3.begin(), + d_fanouts_3.end()); + EXPECT_TRUE(d_fanouts_3_set.contains( + FaninViewType(&graph_view, a_node->node_index(), 1))); + EXPECT_TRUE(d_fanouts_3_set.contains( + FaninViewType(&graph_view, a_node->node_index(), 3))); + + // Invalid or empty. + const std::vector no_fanouts; + EXPECT_EQ(d_node->GetRegularFanout(-2), no_fanouts); + EXPECT_EQ(d_node->GetRegularFanout(Graph::kControlSlot), no_fanouts); + EXPECT_EQ(d_node->GetRegularFanout(0), no_fanouts); + EXPECT_EQ(d_node->GetRegularFanout(1), no_fanouts); + EXPECT_EQ(d_node->GetRegularFanout(4), no_fanouts); + EXPECT_EQ(d_node->GetRegularFanout(5), no_fanouts); +} + +TYPED_TEST(TypedFanoutTest, GetControlledFanouts) { + using FaninViewType = typename TypeParam::first_type; + using GraphViewType = typename TypeParam::second_type; + + GraphDef graph = SimpleTestGraph(); + + Status s; + GraphViewType graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + const auto& c_fanouts = c_node->GetControlledFanouts(); + EXPECT_EQ(c_fanouts.size(), 2); + absl::flat_hash_set c_fanouts_set(c_fanouts.begin(), + c_fanouts.end()); + EXPECT_TRUE(c_fanouts_set.contains( + FaninViewType(&graph_view, b_node->node_index(), Graph::kControlSlot))); + EXPECT_TRUE(c_fanouts_set.contains( + FaninViewType(&graph_view, a_node->node_index(), Graph::kControlSlot))); + + const auto& d_fanouts = d_node->GetControlledFanouts(); + FaninViewType c_control_fanout(&graph_view, c_node->node_index(), + Graph::kControlSlot); + if (std::is_same::value) { + ASSERT_EQ(d_fanouts.size(), 2); + EXPECT_EQ(d_fanouts[0], c_control_fanout); + EXPECT_EQ(d_fanouts[1], c_control_fanout); + } else { // MutableGraphView will dedup control dependency. + ASSERT_EQ(d_fanouts.size(), 1); + EXPECT_EQ(d_fanouts[0], c_control_fanout); + } + + const auto& a_fanouts = a_node->GetControlledFanouts(); + EXPECT_EQ(a_fanouts.size(), 0); +} + +TYPED_TEST(TypedNodeViewTest, NumRegularFanins) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + EXPECT_EQ(a_node->NumRegularFanins(), 4); + EXPECT_EQ(b_node->NumRegularFanins(), 2); + EXPECT_EQ(c_node->NumRegularFanins(), 0); + EXPECT_EQ(d_node->NumRegularFanins(), 0); +} + +TYPED_TEST(TypedNodeViewTest, NumControllingFanins) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + EXPECT_EQ(a_node->NumControllingFanins(), 1); + EXPECT_EQ(b_node->NumControllingFanins(), 1); + if (std::is_same::value) { + EXPECT_EQ(c_node->NumControllingFanins(), 2); + } else { + EXPECT_EQ(c_node->NumControllingFanins(), 1); + } + EXPECT_EQ(d_node->NumControllingFanins(), 0); +} + +TYPED_TEST(TypedNodeViewTest, NumRegularFanouts) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + EXPECT_EQ(a_node->NumRegularFanouts(), 0); + EXPECT_EQ(b_node->NumRegularFanouts(), 2); + EXPECT_EQ(c_node->NumRegularFanouts(), 1); + EXPECT_EQ(d_node->NumRegularFanouts(), 3); +} + +TYPED_TEST(TypedNodeViewTest, NumControlledFanouts) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + EXPECT_EQ(a_node->NumControlledFanouts(), 0); + EXPECT_EQ(b_node->NumControlledFanouts(), 0); + EXPECT_EQ(c_node->NumControlledFanouts(), 2); + if (std::is_same::value) { + EXPECT_EQ(d_node->NumControlledFanouts(), 2); + } else { + EXPECT_EQ(d_node->NumControlledFanouts(), 1); + } +} + +TYPED_TEST(TypedNodeViewTest, HasFanin) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + // Existing regular fanin. + EXPECT_TRUE(a_node->HasFanin({&graph_view, b_node->node_index(), 2})); + // Missing regular fanin. + EXPECT_FALSE(a_node->HasFanin({&graph_view, c_node->node_index(), 4})); + // Existing controlling fanin. + EXPECT_TRUE(a_node->HasFanin( + {&graph_view, c_node->node_index(), Graph::kControlSlot})); + // Missing controlling fanin. + EXPECT_FALSE(a_node->HasFanin( + {&graph_view, b_node->node_index(), Graph::kControlSlot})); + // Bad fanins. + EXPECT_FALSE(a_node->HasFanin({&graph_view, a_node->node_index(), 0})); + EXPECT_FALSE(a_node->HasFanin( + {&graph_view, b_node->node_index(), internal::kMissingSlot})); +} + +TYPED_TEST(TypedNodeViewTest, HasFanout) { + GraphDef graph = SimpleTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + auto* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + // Existing regular fanout. + EXPECT_TRUE(b_node->HasFanout({&graph_view, a_node->node_index(), 2})); + // Missing regular fanout. + EXPECT_FALSE(b_node->HasFanout({&graph_view, a_node->node_index(), 1})); + // Existing controlled fanout. + EXPECT_TRUE(d_node->HasFanout( + {&graph_view, c_node->node_index(), Graph::kControlSlot})); + // Missing controlled fanout. + EXPECT_FALSE(d_node->HasFanout( + {&graph_view, a_node->node_index(), Graph::kControlSlot})); + // Bad fanouts. + EXPECT_FALSE(d_node->HasFanout({&graph_view, d_node->node_index(), 0})); + EXPECT_FALSE(a_node->HasFanout({&graph_view, b_node->node_index(), 0})); + EXPECT_FALSE(a_node->HasFanout({&graph_view, 4, 0})); + EXPECT_FALSE(d_node->HasFanout( + {&graph_view, b_node->node_index(), internal::kMissingSlot})); +} + +GraphDef SimpleAttrTestGraph() { + return GDef({NDef("a", kNoOp, {}), NDef("b", kNoOp, {}, {{"attr", 1}}), + NDef("c", kNoOp, {}, {{"attr_1", "a"}, {"attr_2", 2.0f}})}, + /*funcs=*/{}); +} + +TYPED_TEST(TypedNodeViewTest, GetAttr) { + GraphDef graph = SimpleAttrTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + EXPECT_EQ(c_node->GetAttr("attr_1")->s(), "a"); +} + +TYPED_TEST(TypedNodeViewTest, GetAttrs) { + GraphDef graph = SimpleAttrTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + const auto& actual_attrs = c_node->GetAttrs(); + EXPECT_EQ(actual_attrs.size(), 2); + const auto* attr_1 = actual_attrs.Find("attr_1"); + EXPECT_NE(attr_1, nullptr); + EXPECT_EQ(attr_1->s(), "a"); + const auto* attr_2 = actual_attrs.Find("attr_2"); + EXPECT_NE(attr_2, nullptr); + EXPECT_EQ(attr_2->f(), 2.0f); +} + +TYPED_TEST(TypedNodeViewTest, NumAttrs) { + GraphDef graph = SimpleAttrTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + auto* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + EXPECT_EQ(a_node->NumAttrs(), 0); + EXPECT_EQ(b_node->NumAttrs(), 1); + EXPECT_EQ(c_node->NumAttrs(), 2); +} + +TYPED_TEST(TypedNodeViewTest, HasAttr) { + GraphDef graph = SimpleAttrTestGraph(); + + Status s; + TypeParam graph_view(&graph, &s); + TF_ASSERT_OK(s); + + auto* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + EXPECT_TRUE(c_node->HasAttr("attr_1")); + EXPECT_FALSE(c_node->HasAttr("attr")); +} + +class MutationTest : public GrapplerTest { + public: + void CompareGraphViewWithGraph(MutableGraphView* graph_view, + const GraphDef& expected_graph) { + Status s; + GraphView expected_graph_view(&expected_graph, &s); + TF_ASSERT_OK(s); + + EXPECT_EQ(graph_view->NumNodes(), expected_graph_view.NumNodes()); + + for (const NodeView& expected_node_view : expected_graph_view.GetNodes()) { + const string& node_name = expected_node_view.GetName(); + MutableNodeView* node_view = graph_view->GetNode(node_name); + ASSERT_NE(node_view, nullptr); + + EXPECT_EQ(node_view->GetName(), expected_node_view.GetName()); + + EXPECT_EQ(node_view->GetOp(), expected_node_view.GetOp()); + + EXPECT_EQ(node_view->GetDevice(), expected_node_view.GetDevice()); + + const int actual_num_fanins = node_view->node()->input_size(); + EXPECT_EQ(actual_num_fanins, expected_node_view.node()->input_size()); + + const int expected_num_regular_fanins = + expected_node_view.NumRegularFanins(); + bool same_num_regular_fanins = + node_view->NumRegularFanins() == expected_num_regular_fanins; + EXPECT_TRUE(same_num_regular_fanins); + for (int i = 0; i < expected_num_regular_fanins; ++i) { + const auto& expected_fanin = expected_node_view.GetRegularFanin(i); + + auto* actual_fanin_node = + graph_view->GetNode(expected_fanin.node_view()->GetName()); + ASSERT_NE(actual_fanin_node, nullptr); + EXPECT_TRUE( + node_view->HasFanin({actual_fanin_node, expected_fanin.index()})); + if (i < node_view->NumRegularFanins()) { + auto& actual_fanin = node_view->GetRegularFanin(i); + EXPECT_EQ(actual_fanin, MutableFanoutView(actual_fanin_node, + expected_fanin.index())); + EXPECT_EQ(actual_fanin.node_index(), + actual_fanin.node_view()->node_index()); + } + } + + if (same_num_regular_fanins) { + for (int i = 0; i < expected_num_regular_fanins; ++i) { + const auto& fanin = node_view->GetRegularFanin(i); + EXPECT_EQ(ParseTensorName(node_view->node()->input(i)), + TensorId(fanin.node_view()->GetName(), fanin.index())); + } + } + + const int expected_num_controlling_fanins = + expected_node_view.NumControllingFanins(); + bool same_num_controlling_fanins = + node_view->NumControllingFanins() == expected_num_controlling_fanins; + EXPECT_TRUE(same_num_controlling_fanins); + for (int i = 0; i < expected_num_controlling_fanins; ++i) { + auto& expected_fanin = expected_node_view.GetControllingFanins()[i]; + + auto* actual_fanin_node = + graph_view->GetNode(expected_fanin.node_view()->GetName()); + ASSERT_NE(actual_fanin_node, nullptr); + MutableFanoutView actual_fanin(actual_fanin_node, + expected_fanin.index()); + EXPECT_TRUE(node_view->HasFanin(actual_fanin)); + + int found = 0; + for (const auto& actual_fanin : node_view->GetControllingFanins()) { + if (actual_fanin.index() == expected_fanin.index() && + actual_fanin.node_view()->GetName() == + expected_fanin.node_view()->GetName()) { + EXPECT_EQ(actual_fanin.node_index(), + actual_fanin.node_view()->node_index()); + ++found; + } + } + EXPECT_EQ(found, 1); + } + + if (same_num_controlling_fanins && same_num_regular_fanins) { + for (int i = 0; i < expected_num_controlling_fanins; ++i) { + const auto& fanin = node_view->GetControllingFanins()[i]; + EXPECT_EQ(ParseTensorName(node_view->node()->input( + i + expected_num_regular_fanins)), + TensorId(fanin.node_view()->GetName(), fanin.index())); + } + } + + EXPECT_EQ(node_view->NumRegularFanouts(), + expected_node_view.NumRegularFanouts()); + const int num_output_ports = + expected_node_view.GetRegularFanouts().size(); + ASSERT_EQ(node_view->GetRegularFanouts().size(), num_output_ports); + for (int i = 0; i < num_output_ports; ++i) { + auto& expected_fanouts_at_port_i = node_view->GetRegularFanouts()[i]; + const int num_fanouts_at_port = expected_fanouts_at_port_i.size(); + + auto& actual_fanouts_at_port_i = node_view->GetRegularFanouts()[i]; + EXPECT_EQ(actual_fanouts_at_port_i.size(), num_fanouts_at_port); + + for (int j = 0; j < num_fanouts_at_port; ++j) { + auto& expected_fanout = expected_fanouts_at_port_i[j]; + + auto* actual_fanout_node = + graph_view->GetNode(expected_fanout.node_view()->GetName()); + + ASSERT_NE(actual_fanout_node, nullptr); + MutableFaninView actual_fanout(actual_fanout_node, + expected_fanout.index()); + EXPECT_TRUE(node_view->HasFanout(actual_fanout)); + + int found = 0; + for (const auto& fanout : actual_fanouts_at_port_i) { + if (fanout.index() == expected_fanout.index() && + fanout.node_view()->GetName() == + expected_fanout.node_view()->GetName()) { + EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index()); + ++found; + } + } + EXPECT_EQ(found, 1); + } + } + + const int num_controlled_fanouts = + expected_node_view.NumControlledFanouts(); + EXPECT_EQ(node_view->NumControlledFanouts(), num_controlled_fanouts); + for (int i = 0; i < num_controlled_fanouts; ++i) { + const auto& expected_fanout = + expected_node_view.GetControlledFanouts()[i]; + + auto* actual_fanout_node = + graph_view->GetNode(expected_fanout.node_view()->GetName()); + ASSERT_NE(actual_fanout_node, nullptr); + MutableFaninView actual_fanout(actual_fanout_node, + expected_fanout.index()); + EXPECT_TRUE(node_view->HasFanout(actual_fanout)); + + int found = 0; + for (const auto& fanout : node_view->GetControlledFanouts()) { + if (fanout.index() == expected_fanout.index() && + fanout.node_view()->GetName() == + expected_fanout.node_view()->GetName()) { + EXPECT_EQ(fanout.node_index(), fanout.node_view()->node_index()); + ++found; + } + } + EXPECT_EQ(found, 1); + } + + EXPECT_EQ(node_view->NumAttrs(), expected_node_view.NumAttrs()); + for (const auto& expected_attr : expected_node_view.GetAttrs()) { + auto* attr = node_view->GetAttr(expected_attr.first); + EXPECT_TRUE(AreAttrValuesEqual(*attr, expected_attr.second)); + } + } + CompareGraphs(*graph_view->graph(), expected_graph); + } +}; + +constexpr char kDeviceCPU0[] = "/device:CPU:0"; +constexpr char kDeviceGPU0[] = "/device:GPU:0"; + +GraphDef SimpleTestGraphForMutation() { + return GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)}, + /*funcs=*/{}); +} + +TEST_F(MutationTest, AddNewNode) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef empty_node; + mutation->AddNode(std::move(empty_node), &s); + TF_EXPECT_OK(s); + s = errors::Internal("error"); + + NodeDef valid_node = + NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo"); + mutation->AddNode(std::move(valid_node), &s); + TF_EXPECT_OK(s); + + NodeDef bad_node_1 = + NDef("bad", "IdentityN", {"^b", "a:1"}, {{"N", 1}}, "foo"); + mutation->AddNode(std::move(bad_node_1), &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Mutation::AddNode error: node 'bad' has regular fanin 'a:1' after " + "controlling fanins."); + + NodeDef bad_node_2 = NDef("bad", "IdentityN", {"bad:1"}, {}, "foo"); + mutation->AddNode(std::move(bad_node_2), &s); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "Mutation::AddNode error: node 'bad' has self cycle fanin " + "'bad:1'."); + + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, NewNodeBadFaninsAfterAdd) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef valid_node = + NDef("valid", "IdentityN", {"a:1", "^b"}, {{"N", 1}}, "foo"); + MutationNewNode new_node = mutation->AddNode(std::move(valid_node), &s); + + mutation->AddOrUpdateRegularFanin(new_node, 1, {"valid", 2}); + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: new node 'valid' is ill-formed."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, NewNodesConflictingNames) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node_1 = NDef("a", "", {}); + mutation->AddNode(std::move(new_node_1), &s); + TF_EXPECT_OK(s); + + NodeDef new_node_2 = NDef("a", "", {}); + mutation->AddNode(std::move(new_node_2), &s); + TF_EXPECT_OK(s); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: multiple nodes with the name: 'a' exists in " + "Mutation."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, UpdateNodeAndAddSelfLoop) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->AddControllingFanin(d_node, "d"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: inplace updated node 'd' is ill-formed."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, RenameNodeAndAddSelfLoop) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->UpdateNodeName(d_node, "e"); + mutation->AddControllingFanin(d_node, "e"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: renamed updated node 'e' ('d') is ill-formed."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, ExistingNodesConflictingNames) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + mutation->UpdateNodeName(a_node, "b"); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + mutation->UpdateNodeOp(b_node, "Identity"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: multiple nodes with the name: 'b' exists in " + "Mutation."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, NewAndExistingNodesConflictingNames) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node = NDef("a", "", {}); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + mutation->UpdateNodeDevice(a_node, "foo"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: multiple nodes with the name: 'a' exists in " + "Mutation."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, NewAndExistingRenamedNodesConflictingNames) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node = NDef("e", "", {}); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->UpdateNodeName(d_node, "e"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: multiple nodes with the name: 'e' exists in " + "Mutation."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, RemoveNodesWithFanouts) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + mutation->RemoveNode(b_node); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'b'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->RemoveNode(d_node); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, SwapNodeNamesWithCycle) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->UpdateNodeName(d_node, "b"); + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + mutation->UpdateNodeName(b_node, "d"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: renamed updated node 'b' ('d') is ill-formed."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + mutation->AddOrUpdateRegularFanin(d_node, 1, {"d", 3}); + mutation->RemoveControllingFanin(d_node, "b"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("d", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {"a:2", "d:3", "a:4", "^c"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, RenamedNodeWithFanouts) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + mutation->UpdateNodeName(a_node, "b"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'a'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + mutation->UpdateNodeName(a_node, "a"); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + mutation->UpdateNodeName(b_node, "e"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing " + "node 'b'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +TEST_F(MutationTest, RemoveExistingNodeAndReplaceWithNewNode) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->RemoveNode(d_node); + + NodeDef new_node = NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0), + NDef("d", kNoOp, {"c:8", "^a"}, {}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, UpdateNodeNameAndRemoveFanins) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->UpdateNodeName(d_node, "e"); + mutation->RemoveRegularFanin(d_node, 1); + mutation->RemoveRegularFanin(d_node, 2); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0), + NDef("e", kNoOp, {"a:2", "^c", "^b"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, UpdateNodeNameAndRemoveRegularFanout) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + mutation->UpdateNodeName(a_node, "e"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'a'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->RemoveRegularFanin(d_node, 2); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'a'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + mutation->AddOrUpdateRegularFanin(d_node, 0, {"b", 1}); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("e", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {}, {}, kDeviceCPU0), + NDef("c", kNoOp, {}, {}, kDeviceCPU0), + NDef("d", kNoOp, {"b:1", "b:3", "^c", "^b"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, UpdateNodeNameAndRemoveControlledFanout) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + mutation->UpdateNodeName(c_node, "e"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'c'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->UpdateNodeDevice(d_node, kDeviceGPU0); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + expected_error_msg = + "Mutation::Apply error: fanout 'd' exist for missing node 'c'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); + + mutation->RemoveControllingFanin(d_node, "c"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {}, kDeviceCPU0), + NDef("b", kNoOp, {}, {}, kDeviceCPU0), + NDef("e", kNoOp, {}, {}, kDeviceCPU0), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^b"}, + {{"attr_1", "a"}, {"attr_2", 2.0f}}, kDeviceGPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, EmptyMutation) { + GraphDef graph = SimpleTestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + TF_EXPECT_OK(mutation->Apply()); + CompareGraphViewWithGraph(&graph_view, SimpleTestGraphForMutation()); +} + +constexpr char kIdentity[] = "Identity"; +constexpr char kDeviceCPU1[] = "/device:CPU:1"; +constexpr char kDeviceGPU1[] = "/device:GPU:1"; + +GraphDef TestGraphForMutation() { + return GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)}, + /*funcs=*/{}); +} + +TEST_F(MutationTest, SwapNodeNamesWithNoCycle) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + MutableNodeView* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + + mutation->UpdateNodeName(b_node, "c"); + mutation->UpdateNodeName(c_node, "b"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("b", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, RemoveMultipleDependentNodes) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + mutation->RemoveNode(c_node); + mutation->RemoveNode(d_node); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +constexpr char kDeviceGPU2[] = "/device:GPU:2"; + +TEST_F(MutationTest, AddSimpleNewNode) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node = + NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("new_node", kIdentity, {}, {{"T", DT_INT64}}, kDeviceGPU2)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +constexpr char kDeviceGPU3[] = "/device:GPU:3"; + +TEST_F(MutationTest, AddAndUpdateNodesWithFanins) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node_1 = NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"}, + {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2); + mutation->AddNode(std::move(new_node_1), &s); + TF_EXPECT_OK(s); + + NodeDef new_node_2 = + NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"}, + {{"new_node_2_attr_1", 9}}, kDeviceGPU3); + mutation->AddNode(std::move(new_node_2), &s); + TF_EXPECT_OK(s); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + mutation->AddOrUpdateRegularFanin(d_node, 3, {"c", 6}); + mutation->AddOrUpdateRegularFanin(d_node, 1, {"new_node_1", 5}); + mutation->AddControllingFanin(d_node, "new_node_2"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, + {"a:2", "new_node_1:5", "a:4", "c:6", "^c", "^b", "^new_node_2"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("new_node_1", kNoOp, {"a:2", "d:5", "^b", "^c"}, + {{"new_node_1_attr_1", 5.0f}}, kDeviceGPU2), + NDef("new_node_2", kNoOp, {"a:3", "new_node_1:5", "^d", "^new_node_1"}, + {{"new_node_2_attr_1", 9}}, kDeviceGPU3)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, UpdateNodeNameToReplaceExistingNode) { + auto test_graph = []() { + return GDef( + {NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1), + NDef("c", kNoOp, {"b:4", "^a"}, {{"attr_c", "test"}}, kDeviceGPU2), + NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + + mutation->UpdateNodeName(b_node, "c"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {{"attr_a", 8}}, kDeviceCPU0), + NDef("c", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^a", "^c"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU3)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, NewNodeWithMutations) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + NodeDef new_node_def = NDef("node", kNoOp, {"a:2", "b:3", "^c"}, + {{"attr_1", 1}, {"attr_2", 2.0f}}, kDeviceGPU3); + MutationNewNode new_node = mutation->AddNode(std::move(new_node_def), &s); + TF_EXPECT_OK(s); + + mutation->AddControllingFanin(new_node, "a"); + mutation->RemoveControllingFanin(new_node, "c"); + mutation->AddOrUpdateRegularFanin(new_node, 0, {"b", 6}); + mutation->RemoveRegularFanin(new_node, 1); + mutation->UpdateNodeName(new_node, "new_node"); + mutation->UpdateNodeOp(new_node, kIdentity); + mutation->UpdateNodeDevice(new_node, kDeviceGPU2); + AttrValue attr_3; + attr_3.set_s("new_node_attr"); + mutation->AddOrUpdateNodeAttr(new_node, "attr_3", attr_3); + AttrValue attr_1; + attr_1.set_b(true); + mutation->AddOrUpdateNodeAttr(new_node, "attr_1", attr_1); + mutation->RemoveNodeAttr(new_node, "attr_2"); + AttrValue attr_4; + attr_4.set_type(DT_FLOAT); + mutation->AddOrUpdateNodeAttr(new_node, "T", attr_4); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("new_node", kIdentity, {"b:6", "^a"}, + {{"attr_1", true}, {"attr_3", "new_node_attr"}, {"T", DT_FLOAT}}, + kDeviceGPU2)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, UpdatedNodeWithNonFaninMutations) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->UpdateNodeName(d_node, "e"); + mutation->UpdateNodeOp(d_node, kIdentity); + mutation->UpdateNodeDevice(d_node, kDeviceGPU2); + AttrValue attr_d_1; + attr_d_1.set_b(false); + mutation->AddOrUpdateNodeAttr(d_node, "attr_d_1", attr_d_1); + AttrValue attr_e_3; + attr_e_3.set_s("test_string"); + mutation->AddOrUpdateNodeAttr(d_node, "attr_e_3", attr_e_3); + mutation->RemoveNodeAttr(d_node, "attr_d_2"); + AttrValue attr_e_4; + attr_e_4.set_type(DT_INT64); + mutation->AddOrUpdateNodeAttr(d_node, "T", attr_e_4); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("e", kIdentity, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", false}, {"attr_e_3", "test_string"}, {"T", DT_INT64}}, + kDeviceGPU2)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, Reset) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->UpdateNodeName(a_node, "e"); + mutation->AddNode({}, &s); + TF_EXPECT_OK(s); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + string expected_error_msg = + "Mutation::Apply error: fanout 'b' exist for missing node 'a'."; + EXPECT_EQ(s.error_message(), expected_error_msg); + CompareGraphViewWithGraph(&graph_view, TestGraphForMutation()); + + mutation->Reset(); + TF_EXPECT_OK(mutation->Apply()); + CompareGraphViewWithGraph(&graph_view, TestGraphForMutation()); +} + +TEST_F(MutationTest, RenameNodeAndAddNewNodeWithRenamedNodeOldName) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->UpdateNodeName(b_node, "e"); + + NodeDef new_node = + NDef("b", kIdentity, {"c:2"}, {{"T", DT_INT64}}, kDeviceGPU3); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("e", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("b", kIdentity, {"c:2"}, {{"T", DT_INT64}}, kDeviceGPU3)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, ShiftNodesWithFanouts) { + auto test_graph = []() { + return GDef({NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^c", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, + kDeviceGPU0)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* c_node = graph_view.GetNode("c"); + ASSERT_NE(c_node, nullptr); + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->RemoveControllingFanin(d_node, "c"); + mutation->RemoveNode(c_node); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("d", kNoOp, {"a:2", "b:3", "a:4", "^a", "^b"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("b", kNoOp, {"a:2"}, {{"attr_b", 3.0f}}, kDeviceCPU0), + NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, RemoveFaninFanoutAndShiftFanout) { + auto test_graph = []() { + return GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2", "a:1"}, {}, kDeviceGPU1), + NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->RemoveRegularFanin(b_node, 1); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = + GDef({NDef("a", kNoOp, {}, {}, kDeviceGPU0), + NDef("b", kNoOp, {"a:2"}, {}, kDeviceGPU1), + NDef("c", kNoOp, {"a:1", "a:2"}, {}, kDeviceGPU2)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +TEST_F(MutationTest, ConsecutiveMutations) { + GraphDef graph = TestGraphForMutation(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* b_node = graph_view.GetNode("b"); + ASSERT_NE(b_node, nullptr); + MutableNodeView* d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + mutation->RemoveNode(b_node); + mutation->AddOrUpdateRegularFanin(d_node, 1, {"c", 5}); + mutation->RemoveControllingFanin(d_node, "b"); + + NodeDef new_node_1 = NDef("new_node_1", kIdentity, {"a:3", "d:5", "^d"}, + {{"T", DT_FLOAT}}, kDeviceGPU2); + MutationNewNode new_node_1_node = + mutation->AddNode(std::move(new_node_1), &s); + TF_EXPECT_OK(s); + + mutation->AddOrUpdateRegularFanin(new_node_1_node, 0, {"c", 5}); + mutation->RemoveRegularFanin(new_node_1_node, 1); + mutation->AddOrUpdateRegularFanin(new_node_1_node, 1, {"a", 6}); + mutation->AddControllingFanin(new_node_1_node, "a"); + mutation->RemoveControllingFanin(new_node_1_node, "d"); + + TF_EXPECT_OK(mutation->Apply()); + GraphDef expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, {"a:2", "c:5", "a:4", "^c"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}}, + kDeviceGPU2)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); + + d_node = graph_view.GetNode("d"); + ASSERT_NE(d_node, nullptr); + + mutation->AddOrUpdateRegularFanin(d_node, 3, {"new_node_2", 6}); + mutation->AddOrUpdateRegularFanin(d_node, 1, {"new_node_1", 8}); + mutation->AddControllingFanin(d_node, "new_node_2"); + mutation->AddControllingFanin(d_node, "a"); + mutation->RemoveControllingFanin(d_node, "c"); + + NodeDef new_node_2 = + NDef("new_node_2", kNoOp, {"c:4", "new_node_1:5", "^d", "^c"}); + MutationNewNode new_node_2_node = + mutation->AddNode(std::move(new_node_2), &s); + TF_EXPECT_OK(s); + + mutation->UpdateNodeDevice(new_node_2_node, kDeviceGPU3); + mutation->AddOrUpdateRegularFanin(new_node_2_node, 0, {"new_node_1", 4}); + mutation->RemoveRegularFanin(new_node_2_node, 1); + mutation->RemoveControllingFanin(new_node_2_node, "c"); + mutation->AddControllingFanin(new_node_2_node, "a"); + mutation->AddControllingFanin(new_node_2_node, "new_node_1"); + + TF_EXPECT_OK(mutation->Apply()); + expected_graph = GDef( + {NDef("a", kIdentity, {}, {{"attr_a", 8}, {"T", DT_FLOAT}}, kDeviceGPU0), + NDef("c", kNoOp, {"^a"}, {{"attr_c", "test"}}, kDeviceCPU1), + NDef("d", kNoOp, + {"a:2", "new_node_1:8", "a:4", "new_node_2:6", "^new_node_2", "^a"}, + {{"attr_d_1", "a"}, {"attr_d_2", 2.0f}}, kDeviceGPU1), + NDef("new_node_1", kIdentity, {"c:5", "a:6", "^a"}, {{"T", DT_FLOAT}}, + kDeviceGPU2), + NDef("new_node_2", kNoOp, {"new_node_1:4", "^d", "^a", "^new_node_1"}, + {}, kDeviceGPU3)}, + /*funcs=*/{}); + CompareGraphViewWithGraph(&graph_view, expected_graph); +} + +constexpr char kMatchingFiles[] = "MatchingFiles"; + +TEST_F(MutationTest, OpWithUnsupportedDevice) { + auto test_graph = []() { + return GDef({NDef("a", kMatchingFiles, {}, {}, kDeviceCPU0)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + // Unsupported device. + mutation->UpdateNodeDevice(a_node, kDeviceGPU1); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + CompareGraphViewWithGraph(&graph_view, test_graph()); + + mutation->Reset(); + + // New node with unsupported device. + NodeDef new_node = NDef("new_node", kMatchingFiles, {}, {}, kDeviceGPU2); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + CompareGraphViewWithGraph(&graph_view, test_graph()); +} + +TEST_F(MutationTest, OpMissingAttribute) { + auto test_graph = []() { + return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + // Remove necessary attribute. + mutation->RemoveNodeAttr(a_node, "T"); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + CompareGraphViewWithGraph(&graph_view, test_graph()); + + mutation->Reset(); + + // New node without necessary attribute. + NodeDef new_node = NDef("new_node", kIdentity, {}, {}, kDeviceGPU2); + mutation->AddNode(std::move(new_node), &s); + TF_EXPECT_OK(s); + + s = mutation->Apply(); + EXPECT_FALSE(s.ok()); + CompareGraphViewWithGraph(&graph_view, test_graph()); +} + +TEST_F(MutationTest, EmptyMutationUpdateIndexPersisting) { + auto test_graph = []() { + return GDef({NDef("a", kIdentity, {}, {{"T", DT_FLOAT}}, kDeviceGPU0)}, + /*funcs=*/{}); + }; + + GraphDef graph = test_graph(); + + Status s; + MutableGraphView graph_view(&graph, &s); + TF_ASSERT_OK(s); + + MutableNodeView* a_node = graph_view.GetNode("a"); + ASSERT_NE(a_node, nullptr); + + Mutation* mutation = graph_view.GetMutationBuilder(); + + // Empty MutableNodeViewDiff. + mutation->UpdateNodeName(a_node, "a"); + + TF_EXPECT_OK(mutation->Apply()); + CompareGraphViewWithGraph(&graph_view, test_graph()); + + mutation->Reset(); + + // Empty MutableNodeViewDiff, `update_index_` should not persist. + mutation->UpdateNodeName(a_node, "a"); + + TF_EXPECT_OK(mutation->Apply()); + CompareGraphViewWithGraph(&graph_view, test_graph()); +} + +#define RUN_NUM_NODE_NUM_EDGE_BENCHMARK(name) \ + BENCHMARK(name) \ + ->ArgPair(10, 2) \ + ->ArgPair(100, 2) \ + ->ArgPair(1000, 2) \ + ->ArgPair(10000, 2) \ + ->ArgPair(25000, 2) \ + ->ArgPair(50000, 2) \ + ->ArgPair(100000, 2) \ + ->ArgPair(10, 4) \ + ->ArgPair(100, 4) \ + ->ArgPair(1000, 4) \ + ->ArgPair(10000, 4) \ + ->ArgPair(25000, 4) \ + ->ArgPair(50000, 4) \ + ->ArgPair(100000, 4) \ + ->ArgPair(10, 8) \ + ->ArgPair(100, 8) \ + ->ArgPair(1000, 8) \ + ->ArgPair(10000, 8) \ + ->ArgPair(25000, 8) \ + ->ArgPair(50000, 8) \ + ->ArgPair(100000, 8) \ + ->ArgPair(10, 16) \ + ->ArgPair(100, 16) \ + ->ArgPair(1000, 16) \ + ->ArgPair(10000, 16) \ + ->ArgPair(25000, 16) \ + ->ArgPair(50000, 16) \ + ->ArgPair(100000, 16); + +template +static void BM_GraphViewTConstruction(int iters, int num_nodes, + int num_edges_per_node) { + testing::StopTiming(); + GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + Status s; + GraphViewT graph_view(&graph_def, &s); + } + testing::StopTiming(); +} + +static void BM_GraphViewConstruction(int iters, int num_nodes, + int num_edges_per_node) { + BM_GraphViewTConstruction(iters, num_nodes, num_edges_per_node); +} + +static void BM_MutableGraphViewConstruction(int iters, int num_nodes, + int num_edges_per_node) { + BM_GraphViewTConstruction(iters, num_nodes, + num_edges_per_node); +} + +RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction); +RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewConstruction); + +#define RUN_NUM_NODE_BENCHMARK(name) \ + BENCHMARK(name) \ + ->Arg(10) \ + ->Arg(100) \ + ->Arg(1000) \ + ->Arg(10000) \ + ->Arg(25000) \ + ->Arg(50000) \ + ->Arg(100000); + +template +static void BM_GraphViewTConstructionWithControlDependencies( + int iters, int num_fanins_fanouts) { + testing::StopTiming(); + GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins_fanouts, num_fanins_fanouts, + num_fanins_fanouts, num_fanins_fanouts, + /*fanout_unique_index=*/true); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + Status s; + GraphViewT graph_view(&graph_def, &s); + } + testing::StopTiming(); +} + +static void BM_GraphViewConstructionWithControlDependencies( + int iters, int num_fanins_fanouts) { + BM_GraphViewTConstructionWithControlDependencies( + iters, num_fanins_fanouts); +} + +static void BM_MutableGraphViewConstructionWithControlDependencies( + int iters, int num_fanins_fanouts) { + BM_GraphViewTConstructionWithControlDependencies( + iters, num_fanins_fanouts); +} + +RUN_NUM_NODE_BENCHMARK(BM_GraphViewConstructionWithControlDependencies); +RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewConstructionWithControlDependencies); + +template +static void BM_GraphViewTGetNode(int iters, int num_nodes) { + testing::StopTiming(); + GraphDef graph_def = + test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + graph_view.GetNode("out"); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetNode(int iters, int num_nodes) { + BM_GraphViewTGetNode(iters, num_nodes); +} + +static void BM_MutableGraphViewGetNode(int iters, int num_nodes) { + BM_GraphViewTGetNode(iters, num_nodes); +} + +RUN_NUM_NODE_BENCHMARK(BM_GraphViewGetNode); +RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewGetNode); + +#define RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(name) \ + BENCHMARK(name) \ + ->ArgPair(10, 10) \ + ->ArgPair(10, 100) \ + ->ArgPair(10, 1000) \ + ->ArgPair(10, 10000) \ + ->ArgPair(10, 100000) \ + ->ArgPair(100, 10) \ + ->ArgPair(100, 100) \ + ->ArgPair(100, 1000) \ + ->ArgPair(100, 10000) \ + ->ArgPair(100, 100000) \ + ->ArgPair(1000, 10) \ + ->ArgPair(1000, 100) \ + ->ArgPair(1000, 1000) \ + ->ArgPair(1000, 10000) \ + ->ArgPair(1000, 100000) \ + ->ArgPair(10000, 10) \ + ->ArgPair(10000, 100) \ + ->ArgPair(10000, 1000) \ + ->ArgPair(10000, 10000) \ + ->ArgPair(10000, 100000) \ + ->ArgPair(100000, 10) \ + ->ArgPair(100000, 100) \ + ->ArgPair(100000, 1000) \ + ->ArgPair(100000, 10000) \ + ->ArgPair(100000, 100000); + +template +static void BM_GraphViewTGetRegularFanin(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetRegularFanin(0); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetRegularFanin(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanin(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetRegularFanin(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanin(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanin); + +template +static void BM_GraphViewTGetRegularFanout(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetRegularFanout(0); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetRegularFanout(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanout(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetRegularFanout(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanout(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanout); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanout); + +template +static void BM_GraphViewTGetRegularFanins(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetRegularFanins(); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetRegularFanins(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanins(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetRegularFanins(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanins(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanins); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanins); + +template +static void BM_GraphViewTGetRegularFanouts(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetRegularFanouts(); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetRegularFanouts(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanouts(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetRegularFanouts(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetRegularFanouts(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanouts); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanouts); + +template +static void BM_GraphViewTGetControllingFanins(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetControllingFanins(); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetControllingFanins(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetControllingFanins(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetControllingFanins(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetControllingFanins(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControllingFanins); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControllingFanins); + +template +static void BM_GraphViewTGetControlledFanouts(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + auto* node = graph_view.GetNode("node"); + node->GetControlledFanouts(); + } + testing::StopTiming(); +} + +static void BM_GraphViewGetControlledFanouts(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetControlledFanouts(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewGetControlledFanouts(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTGetControlledFanouts(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControlledFanouts); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControlledFanouts); + +template +inline static void BM_GraphViewTHasRegularFanin(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, /*num_controlling_fanins=*/0, + /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false); + Status s; + GraphViewT graph_view(&graph_def, &s); + const int index = IsLast ? num_fanouts - 1 : 0; + auto* node = graph_view.GetNode(absl::StrFormat("out%05d", index)); + auto* fanin = graph_view.GetNode("node"); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + node->HasFanin({&graph_view, fanin->node_index(), 0}); + } + testing::StopTiming(); +} + +static void BM_GraphViewHasRegularFaninFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanin(iters, num_fanins, + num_fanouts); +} + +static void BM_GraphViewHasRegularFaninLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanin(iters, num_fanins, num_fanouts); +} + +static void BM_MutableGraphViewHasRegularFaninFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanin(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasRegularFaninLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanin(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninLast); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninLast); + +template +inline static void BM_GraphViewTHasControllingFanin(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); + Status s; + GraphViewT graph_view(&graph_def, &s); + const int index = IsLast ? num_fanouts - 1 : 0; + auto* node = graph_view.GetNode(absl::StrFormat("control_out%05d", index)); + auto* fanin = graph_view.GetNode("node"); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + node->HasFanin({&graph_view, fanin->node_index(), Graph::kControlSlot}); + } + testing::StopTiming(); +} + +static void BM_GraphViewHasControllingFaninFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControllingFanin(iters, num_fanins, + num_fanouts); +} + +static void BM_GraphViewHasControllingFaninLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControllingFanin(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasControllingFaninFirst(int iters, + int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControllingFanin(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasControllingFaninLast(int iters, + int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControllingFanin(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninLast); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninLast); + +template +inline static void BM_GraphViewTHasRegularFanout(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, /*num_controlling_fanins=*/0, + /*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false); + Status s; + GraphViewT graph_view(&graph_def, &s); + const int index = IsLast ? num_fanins - 1 : 0; + auto* node = graph_view.GetNode(absl::StrFormat("in%05d", index)); + auto* fanout = graph_view.GetNode("node"); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + node->HasFanout({&graph_view, fanout->node_index(), index}); + } + testing::StopTiming(); +} + +static void BM_GraphViewHasRegularFanoutFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_GraphViewHasRegularFanoutLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasRegularFanoutFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasRegularFanoutLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasRegularFanout(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutLast); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutLast); + +template +inline static void BM_GraphViewTHasControlledFanout(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/false); + Status s; + GraphViewT graph_view(&graph_def, &s); + const int index = IsLast ? num_fanins - 1 : 0; + auto* node = graph_view.GetNode(absl::StrFormat("control_in%05d", index)); + auto* fanout = graph_view.GetNode("node"); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + node->HasFanout({&graph_view, fanout->node_index(), Graph::kControlSlot}); + } + testing::StopTiming(); +} + +static void BM_GraphViewHasControlledFanoutFirst(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControlledFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_GraphViewHasControlledFanoutLast(int iters, int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControlledFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasControlledFanoutFirst(int iters, + int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControlledFanout(iters, num_fanins, + num_fanouts); +} + +static void BM_MutableGraphViewHasControlledFanoutLast(int iters, + int num_fanins, + int num_fanouts) { + BM_GraphViewTHasControlledFanout(iters, num_fanins, + num_fanouts); +} + +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst); +RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast); + +} // namespace +} // namespace utils +} // namespace grappler +} // namespace tensorflow