From 3eeaf9f1e16739e953c0242cc5d8cbfc40f79dca Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Wed, 31 Oct 2018 09:42:35 -0700
Subject: [PATCH] [Grappler] Preserve constness of the GraphDef in GraphView.

1. Split GraphView into GraphView and MutableGraphView with separate
   {Input/Output}Port types with different node pointer constness.
2. Properly use GraphView and MutableGraphView in graph properties, and get
   rid of const_cast.
3. Remove const_cast in function optimizer.
4. Migrate GraphView to absl containers and hash.

PiperOrigin-RevId: 219488040
---
 tensorflow/core/grappler/BUILD                |   5 +
 tensorflow/core/grappler/costs/BUILD          |   2 +-
 .../core/grappler/costs/graph_properties.cc   |  24 +-
 tensorflow/core/grappler/graph_view.cc        | 212 ------------
 tensorflow/core/grappler/graph_view.h         | 321 ++++++++++++++----
 tensorflow/core/grappler/graph_view_test.cc   |  27 +-
 .../core/grappler/mutable_graph_view.cc       |  35 +-
 tensorflow/core/grappler/mutable_graph_view.h |  22 +-
 .../core/grappler/mutable_graph_view_test.cc  |   9 +-
 tensorflow/core/grappler/optimizers/BUILD     |   9 +-
 .../grappler/optimizers/data/filter_fusion.cc |   2 +-
 .../grappler/optimizers/data/graph_utils.cc   |   8 +-
 .../optimizers/data/graph_utils_test.cc       |  46 ++-
 .../optimizers/data/hoist_random_uniform.cc   |   8 +-
 .../optimizers/data/latency_all_edges.cc      |   6 +-
 .../optimizers/data/make_numa_aware.cc        |   2 +-
 .../optimizers/data/map_and_batch_fusion.cc   |   3 +-
 .../data/map_and_batch_fusion_test.cc         |   2 +-
 .../optimizers/data/map_and_filter_fusion.cc  |   7 +-
 .../grappler/optimizers/data/map_fusion.cc    |   3 +-
 .../optimizers/data/map_parallelization.cc    |   2 +-
 .../optimizers/data/map_vectorization.cc      |   5 +-
 .../optimizers/data/noop_elimination.cc       |  11 +-
 .../data/shuffle_and_repeat_fusion_test.cc    |   2 +-
 .../grappler/optimizers/function_optimizer.cc |  11 +-
 .../grappler/optimizers/loop_optimizer.cc     |  40 +--
 .../grappler/optimizers/memory_optimizer.cc   |  40 +--
 .../grappler/optimizers/shape_optimizer.cc    |  20 +-
 tensorflow/core/grappler/utils/BUILD          |   1 +
 tensorflow/core/grappler/utils/traversal.cc   |  27 +-
 tensorflow/core/grappler/utils/traversal.h    |   7 +
 .../core/grappler/utils/traversal_test.cc     |  29 +-
 32 files changed, 503 insertions(+), 445 deletions(-)

diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
index 3bad29a2390..c2a9a28a1ca 100644
--- a/tensorflow/core/grappler/BUILD
+++ b/tensorflow/core/grappler/BUILD
@@ -69,6 +69,9 @@ cc_library(
         ":utils",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/hash",
     ],
 )

@@ -82,6 +85,8 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
     ],
 )

diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 144d7f8ce6c..f52735fd64c 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -44,7 +44,7 @@ cc_library(
         "@com_google_absl//absl/memory",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler/utils:topological_sort",
-        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:mutable_graph_view",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core:core_cpu_base",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc
b/tensorflow/core/grappler/costs/graph_properties.cc
index 23aa5aa2103..82439d66147 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -30,7 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/grappler/costs/utils.h"
-#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/functions.h"
@@ -456,10 +456,10 @@ class SymbolicShapeRefiner {
       const GraphView& graph,
       const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
       : graph_(graph),
-        function_library_(OpRegistry::Global(), graph.GetGraph()->library()),
+        function_library_(OpRegistry::Global(), graph.graph()->library()),
         fed_ports_(fed_ports) {
-    graph_def_version_ = graph.GetGraph()->versions().producer();
-    node_to_context_.reserve(graph.GetGraph()->node_size());
+    graph_def_version_ = graph.graph()->versions().producer();
+    node_to_context_.reserve(graph.graph()->node_size());
   }

   const GraphView& graph() const { return graph_; }
@@ -512,7 +512,7 @@ class SymbolicShapeRefiner {
     // Placeholder with Const) don't affect one in
     // fun_to_grappler_function_item_.
     GrapplerFunctionItem grappler_function_item = it->second;
-    GraphView gv(&grappler_function_item.graph);
+    MutableGraphView gv(&grappler_function_item.graph);

     // Forward shapes from function input nodes to argument nodes.
     for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
@@ -532,7 +532,7 @@ class SymbolicShapeRefiner {
             "Function inputs should not contain control nodes.");
       }

-      NodeDef* input_node = graph_.GetNode(node_name);
+      const NodeDef* input_node = graph_.GetNode(node_name);
       if (input_node == nullptr) {
         return errors::FailedPrecondition(node_name,
                                           " was not found in the graph.");
@@ -566,7 +566,7 @@ class SymbolicShapeRefiner {
     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
       const string& input = function_node->input(i);
       const string& node_name = NodeName(input);
-      NodeDef* input_node = graph_.GetNode(node_name);
+      const NodeDef* input_node = graph_.GetNode(node_name);
       if (IsConstant(*input_node)) {
         TF_CHECK_OK(
             ReplaceInputWithConst(*input_node, i, &grappler_function_item));
@@ -1441,8 +1441,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
       continue;
     }
     ShapeHandle input = in->output(fanin.src.port_id);
-    CHECK_EQ(fanin.tgt.node, node);
-    c->SetInput(fanin.tgt.port_id, input);
+    CHECK_EQ(fanin.dst.node, node);
+    c->SetInput(fanin.dst.port_id, input);
     if (!out_initialized) {
       out_initialized = true;
       out = input;
@@ -1673,7 +1673,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
     }
   }

-  GraphView graph_view(const_cast<GraphDef*>(&item_.graph));
+  GraphView graph_view(&item_.graph);

   // List the resources and the nodes using them. Also collect the Merge nodes,
   // fed nodes, and primary inputs.
@@ -1725,10 +1725,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
   for (const auto& resource : resources) {
     for (const NodeDef* src : resource.second.first) {
       resource_handles[src] = resource.first;
-      for (const NodeDef* tgt : resource.second.second) {
+      for (const NodeDef* dst : resource.second.second) {
         // Add control edges from enqueue to dequeue nodes to ensure they are
         // processed in their logical order.
- extra_deps.emplace_back(src, tgt); + extra_deps.emplace_back(src, dst); } } } diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index de0a63fc4e3..9b3958b6c17 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { return OpPortIdToArgId(node, op.input_arg(), port_id); } -GraphView::GraphView(GraphDef* graph) : graph_(graph) { - for (int i = 0; i < graph_->node_size(); i++) { - auto node = graph_->mutable_node(i); - AddUniqueNodeOrDie(node); - } - - for (NodeDef& node : *graph_->mutable_node()) { - AddFanouts(&node); - } -} - -void GraphView::AddUniqueNodeOrDie(NodeDef* node) { - auto result = nodes_.emplace(node->name(), node); - // Check that the graph doesn't contain multiple nodes with the same name. - CHECK(result.second) << "Non unique node name detected: " << node->name(); -} - -void GraphView::AddFanouts(NodeDef* node) { - for (int i = 0; i < node->input_size(); ++i) { - OutputPort fanin; - const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); - fanin.node = nodes_[fanin_name]; - - InputPort input; - input.node = node; - if (fanin.port_id < 0) { - input.port_id = -1; - } else { - input.port_id = i; - num_regular_outputs_[fanin.node] = - std::max(num_regular_outputs_[fanin.node], fanin.port_id); - } - - fanouts_[fanin].insert(input); - } -} - -NodeDef* GraphView::GetNode(const string& node_name) const { - auto it = nodes_.find(node_name); - if (it == nodes_.end()) { - return nullptr; - } - return it->second; -} - -GraphView::InputPort GraphView::GetInputPort(const string& node_name, - int port_id) const { - InputPort result; - result.node = GetNode(node_name); - // TODO(bsteiner): verify that the node has at least port_id input ports - result.port_id = port_id; - return result; -} - -GraphView::OutputPort GraphView::GetOutputPort(const string& node_name, - int port_id) const { - OutputPort result; - result.node = GetNode(node_name); - // TODO(bsteiner): verify that the node has at least port_id output ports - result.port_id = port_id; - return result; -} - -const std::unordered_set& -GraphView::GetFanout(const GraphView::OutputPort& port) const { - auto it = fanouts_.find(port); - if (it == fanouts_.end()) { - return empty_set_; - } - return it->second; -} - -std::unordered_set -GraphView::GetFanin(const GraphView::InputPort& port) const { - std::unordered_set result; - if (port.port_id >= 0) { - result.insert(GetRegularFanin(port)); - } else { - for (int i = port.node->input_size() - 1; i >= 0; --i) { - OutputPort fanin; - string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id); - if (fanin.port_id < 0) { - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.node = it->second; - result.insert(fanin); - } - } else { - break; - } - } - } - return result; -} - -const GraphView::OutputPort GraphView::GetRegularFanin( - const GraphView::InputPort& port) const { - CHECK_LE(0, port.port_id); - OutputPort fanin; - string fanin_name = - ParseNodeName(port.node->input(port.port_id), &fanin.port_id); - auto it = nodes_.find(fanin_name); - if (it == nodes_.end()) { - fanin.node = nullptr; - } else { - fanin.node = it->second; - } - return fanin; -} - -std::unordered_set -GraphView::GetFanouts(const NodeDef& node, - bool include_controlled_nodes) const { - std::unordered_set result; - OutputPort port; - port.node = const_cast(&node); - const int first_port_id = 
include_controlled_nodes ? -1 : 0; - auto it = num_regular_outputs_.find(&node); - const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1; - - for (int i = first_port_id; i <= last_port_id; ++i) { - port.port_id = i; - auto it = fanouts_.find(port); - if (it != fanouts_.end()) { - result.insert(it->second.begin(), it->second.end()); - } - } - return result; -} - -std::unordered_set -GraphView::GetFanins(const NodeDef& node, - bool include_controlling_nodes) const { - std::unordered_set result; - for (int i = 0; i < node.input_size(); ++i) { - OutputPort fanin; - string fanin_name = ParseNodeName(node.input(i), &fanin.port_id); - if (fanin.port_id < 0) { - if (!include_controlling_nodes) { - break; - } - } - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.node = it->second; - result.insert(fanin); - } - } - return result; -} - -int GraphView::NumFanins(const NodeDef& node, - bool include_controlling_nodes) const { - int count = 0; - for (const string& input : node.input()) { - if (!include_controlling_nodes && IsControlInput(input)) { - break; - } - count += 1; - } - return count; -} - -std::unordered_set -GraphView::GetFanoutEdges(const NodeDef& node, - bool include_controlled_edges) const { - std::unordered_set result; - OutputPort port; - port.node = const_cast(&node); - const int first_port_id = include_controlled_edges ? -1 : 0; - auto it = num_regular_outputs_.find(&node); - const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1; - - for (int i = first_port_id; i <= last_port_id; ++i) { - port.port_id = i; - auto it = fanouts_.find(port); - if (it != fanouts_.end()) { - Edge fanout; - fanout.src.node = const_cast(&node); - fanout.src.port_id = i; - for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { - fanout.tgt = *itr; - result.insert(fanout); - } - } - } - return result; -} - -std::unordered_set -GraphView::GetFaninEdges(const NodeDef& node, - bool include_controlling_edges) const { - std::unordered_set result; - for (int i = 0; i < node.input_size(); ++i) { - Edge fanin; - fanin.tgt.node = const_cast(&node); - fanin.tgt.port_id = i; - string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id); - if (fanin.src.port_id < 0) { - if (!include_controlling_edges) { - break; - } - } - auto it = nodes_.find(fanin_name); - if (it != nodes_.end()) { - fanin.src.node = it->second; - result.insert(fanin); - } - } - return result; -} - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 09c36a13683..495e01d2ebe 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -18,9 +18,16 @@ limitations under the License. 
 #include <unordered_map>
 #include <unordered_set>

+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/hash/hash.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {
@@ -36,114 +43,290 @@ namespace grappler {
 int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
 int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);

-// A utility class to simplify the traversal of a GraphDef.
-class GraphView {
+namespace internal {
+
+// GraphViewInternal is a helper class to simplify graph traversal. It creates
+// an immutable view of the nodes and edges represented by a GraphDef protocol
+// buffer.
+//
+// There are two public classes implementing GraphViewInternal:
+//
+// - GraphView: constructed from a `const GraphDef` and does not allow
+//   mutating the underlying graph through the input/output port lookup
+//   functions (ports hold const pointers to nodes).
+//
+// - MutableGraphView: constructed from a `GraphDef` and allows mutating the
+//   graph through the input/output port lookup functions (ports hold
+//   non-const pointers to nodes); it also has a couple of additional
+//   functions to add/remove/replace nodes in the graph.
+//
+// ------------------------------ !!! WARNING !!! -----------------------------
+//     Removing nodes from the graph outside of MutableGraphView will lead to
+//     segfaults: the view keys its node map with absl::string_view pointing
+//     into the NodeDefs, so removed nodes leave dangling keys and pointers.
+// -----------------------------------------------------------------------------
+//
+template <typename GraphDefT, typename NodeDefT>
+class GraphViewInternal {
  public:
   struct Port {
-    Port() = default;
-    Port(NodeDef* n, int port) : node(n), port_id(port) {}
-
-    // TODO(prazek): ports should keep the constness of GraphView. The only way
-    // to modify graph through the view should be using MutableGraphView.
-    NodeDef* node = nullptr;
-    int port_id = -1;
+    Port() : node(nullptr), port_id(0) {}
+    Port(NodeDefT* n, int port) : node(n), port_id(port) {}

     bool operator==(const Port& other) const {
       return node == other.node && port_id == other.port_id;
     }
-  };

-  struct InputPort : public Port {
-    InputPort() = default;
-    InputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
-    InputPort(const NodeDef* n, int port_id)
-        : Port(const_cast<NodeDef*>(n), port_id) {}
-  };
-
-  struct OutputPort : public Port {
-    OutputPort() = default;
-    OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
+
+    template <typename H>
+    friend H AbslHashValue(H h, const Port& p) {
+      return H::combine(std::move(h), p.node, p.port_id);
+    }
+
+    NodeDefT* node;
+    int port_id;
   };

-  struct HashPort {
-    std::size_t operator()(const Port& port) const {
-      return reinterpret_cast<std::size_t>(port.node) + port.port_id;
-    }
+  struct InputPort : public Port {
+    using Port::Port;
+  };
+
+  struct OutputPort : public Port {
+    using Port::Port;
   };

   struct Edge {
-    OutputPort src;
-    InputPort tgt;
+    Edge(OutputPort s, InputPort d) : src(s), dst(d) {}

     bool operator==(const Edge& other) const {
-      return src == other.src && tgt == other.tgt;
+      return src == other.src && dst == other.dst;
     }
-  };

-  struct HashEdge {
-    std::size_t operator()(const Edge& edge) const {
-      return HashPort()(edge.src) + HashPort()(edge.tgt);
+
+    template <typename H>
+    friend H AbslHashValue(H h, const Edge& e) {
+      return H::combine(std::move(h), e.src, e.dst);
     }
+
+    OutputPort src;
+    InputPort dst;
   };

-  explicit GraphView(GraphDef* graph);
-  GraphDef* GetGraph() const { return graph_; }
-  NodeDef* GetNode(const string& node_name) const;
+  GraphDefT* graph() const { return graph_; }
+
+  // Find a node by name or return `nullptr` if it's not in a graph view.
+  NodeDefT* GetNode(absl::string_view node_name) const {
+    return gtl::FindWithDefault(nodes_, node_name, nullptr);
+  }
+
   // Get the specified input port. Note that the special '-1' port_id can be
   // used to access the controlling nodes (i.e. the nodes connected to
   // node_name through an incoming control dependency).
-  InputPort GetInputPort(const string& node_name, int port_id) const;
+  InputPort GetInputPort(absl::string_view node_name, int port_id) const {
+    return InputPort(GetNode(node_name), port_id);
+  }
+
   // Get the specified output port. Note that the special '-1' port_id can be
   // used to access the controlled nodes (i.e. the nodes connected to
   // node_name through an outgoing control dependency).
-  OutputPort GetOutputPort(const string& node_name, int port_id) const;
+  OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
+    return OutputPort(GetNode(node_name), port_id);
+  }

   // Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
   // of an output (resp. input) port.
-  const std::unordered_set<InputPort, HashPort>& GetFanout(
-      const OutputPort& port) const;
-  std::unordered_set<OutputPort, HashPort> GetFanin(
-      const InputPort& port) const;
+  const absl::flat_hash_set<InputPort>& GetFanout(
+      const OutputPort& port) const {
+    return gtl::FindWithDefault(fanouts_, port, empty_set_);
+  }
+
+  absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
+    if (port.port_id >= 0) return {GetRegularFanin(port)};
+
+    // Collect fanin for the control input.
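+    // (Control dependencies appear in NodeDef.input() as "^node_name";
+    // ParseTensorName() maps them to a negative index, and they always come
+    // after the regular inputs, hence the backwards scan below.)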
+    absl::flat_hash_set<OutputPort> result;
+    for (int i = port.node->input_size() - 1; i >= 0; --i) {
+      TensorId tensor_id = ParseTensorName(port.node->input(i));
+      if (tensor_id.index() >= 0) break;  // we reached regular inputs
+
+      auto it = nodes_.find(tensor_id.node());
+      if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
+    }
+    return result;
+  }

   // Special case: regular (i.e. non-control) input ports can only have one
   // fanin.
-  const OutputPort GetRegularFanin(const InputPort& port) const;
+  const OutputPort GetRegularFanin(const InputPort& port) const {
+    DCHECK_GE(port.port_id, 0);
+    if (port.port_id < 0) return OutputPort();

-  // Get all the input (resp. output) ports in the immediate fanout (resp fanin)
-  // of a node. Include the controlling nodes iff include_controlling_nodes is
-  // true.
-  std::unordered_set<InputPort, HashPort> GetFanouts(
-      const NodeDef& node, bool include_controlled_nodes) const;
-  std::unordered_set<OutputPort, HashPort> GetFanins(
-      const NodeDef& node, bool include_controlling_nodes) const;
+    TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
+    return GetOutputPort(tensor_id.node(), tensor_id.index());
+  }
+
+  // Get all the input (resp. output) ports in the immediate fanout (resp
+  // fanin) of a node. Include the controlling nodes iff
+  // include_controlling_nodes is true.
+  absl::flat_hash_set<InputPort> GetFanouts(
+      const NodeDef& node, bool include_controlled_nodes) const {
+    absl::flat_hash_set<InputPort> result;
+
+    OutputPort port;
+    port.node = const_cast<NodeDefT*>(&node);
+    const int first_port_id = include_controlled_nodes ? -1 : 0;
+    const int last_port_id =
+        gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
+
+    for (int i = first_port_id; i <= last_port_id; ++i) {
+      port.port_id = i;
+      auto it = fanouts_.find(port);
+      if (it != fanouts_.end()) {
+        result.insert(it->second.begin(), it->second.end());
+      }
+    }
+    return result;
+  }
+
+  absl::flat_hash_set<OutputPort> GetFanins(
+      const NodeDef& node, bool include_controlling_nodes) const {
+    absl::flat_hash_set<OutputPort> result;
+    for (int i = 0; i < node.input_size(); ++i) {
+      TensorId tensor_id = ParseTensorName(node.input(i));
+      if (tensor_id.index() < 0 && !include_controlling_nodes) break;
+
+      auto it = nodes_.find(tensor_id.node());
+      if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
+    }
+    return result;
+  }

   // Get the number of ports in the immediate fanin of a node. Count the
   // controlling nodes iff include_controlling_nodes is true.
-  int NumFanins(const NodeDef& node, bool include_controlling_nodes) const;
+  int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
+    int count = 0;
+    for (const string& input : node.input()) {
+      if (!include_controlling_nodes && IsControlInput(input)) {
+        break;
+      }
+      count += 1;
+    }
+    return count;
+  }

-  // Get all the edge in the immediate fanout (resp fanin) of a node. Include
-  // the control edges iff include_controlling_edges is true.
-  std::unordered_set<Edge, HashEdge> GetFanoutEdges(
-      const NodeDef& node, bool include_controlled_edges) const;
-  std::unordered_set<Edge, HashEdge> GetFaninEdges(
-      const NodeDef& node, bool include_controlling_edges) const;
+  // Get the number of ports in the immediate fanout of a node. Count the
+  // controlling nodes iff include_controlling_nodes is true.
+  int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
+    int count = 0;
+
+    OutputPort port;
+    port.node = const_cast<NodeDefT*>(&node);
+    const int first_port_id = include_controlling_nodes ? -1 : 0;
+    const int last_port_id =
+        gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
+
+    for (int i = first_port_id; i <= last_port_id; ++i) {
+      port.port_id = i;
+      auto it = fanouts_.find(port);
+      if (it != fanouts_.end()) count += it->second.size();
+    }
+
+    return count;
+  }
+
+  // Get all the edges in the immediate fanout (resp fanin) of a node.
+  // Include the control edges iff include_controlling_edges is true.
+  absl::flat_hash_set<Edge> GetFanoutEdges(
+      const NodeDef& node, bool include_controlled_edges) const {
+    absl::flat_hash_set<Edge> result;
+
+    OutputPort port;
+    port.node = const_cast<NodeDefT*>(&node);
+    const int first_port_id = include_controlled_edges ? -1 : 0;
+    const int last_port_id =
+        gtl::FindWithDefault(num_regular_outputs_, &node, -1);
+
+    for (int i = first_port_id; i <= last_port_id; ++i) {
+      port.port_id = i;
+      auto it = fanouts_.find(port);
+      if (it != fanouts_.end()) {
+        for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
+          result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
+                         /*dst*/ *itr);
+        }
+      }
+    }
+    return result;
+  }
+
+  absl::flat_hash_set<Edge> GetFaninEdges(
+      const NodeDef& node, bool include_controlling_edges) const {
+    absl::flat_hash_set<Edge> result;
+    for (int i = 0; i < node.input_size(); ++i) {
+      TensorId tensor_id = ParseTensorName(node.input(i));
+      if (tensor_id.index() < 0 && !include_controlling_edges) break;
+
+      auto it = nodes_.find(tensor_id.node());
+      if (it != nodes_.end()) {
+        result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
+                       /*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
+      }
+    }
+    return result;
+  }

  protected:
-  // Add a new `node` to the graph.
-  void AddUniqueNodeOrDie(NodeDef* node);
-  // Add fanout to every `node` input.
-  void AddFanouts(NodeDef* node);
-  std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
-  GraphDef* MutableGraph() { return graph_; }
+  explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}

-  using FanoutsMapType =
-      std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
-                         HashPort>;
-  FanoutsMapType* MutableFanouts() { return &fanouts_; }
+  void AddUniqueNodeOrDie(NodeDefT* node) {
+    auto result = nodes_.emplace(node->name(), node);
+    // TODO(ezhulenev): Replace CHECK with factory method returning
+    // absl::StatusOr (when available).
+    CHECK(result.second) << "Non unique node name detected: " << node->name();
+  }
+
+  void AddFanouts(NodeDefT* node) {
+    for (int i = 0; i < node->input_size(); ++i) {
+      TensorId tensor_id = ParseTensorName(node->input(i));
+      OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
+
+      if (output.port_id < 0) {
+        fanouts_[output].emplace(node, -1);
+      } else {
+        num_regular_outputs_[output.node] =
+            std::max(num_regular_outputs_[output.node], output.port_id);
+        fanouts_[output].emplace(node, i);
+      }
+    }
+  }
+
+  // Access to the mutable internal state for MutableGraphView.
+  absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
+    return &nodes_;
+  }
+
+  absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
+  mutable_fanouts() {
+    return &fanouts_;
+  }

  private:
-  GraphDef* graph_;
-  std::unordered_map<string, NodeDef*> nodes_;
-  std::unordered_set<InputPort, HashPort> empty_set_;
-  FanoutsMapType fanouts_;
-  std::unordered_map<const NodeDef*, int> num_regular_outputs_;
+  GraphDefT* graph_;  // must outlive the graph view
+  absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
+  absl::flat_hash_set<InputPort> empty_set_;
+  absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
+  std::unordered_map<const NodeDefT*, int> num_regular_outputs_;
+};
+
+}  // namespace internal
+
+// Immutable GraphView that keeps the constness of the GraphDef. If you need
+// to mutate the graph or the nodes via the graph view lookup functions, see
+// MutableGraphView.
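+//
+// A minimal read-only usage sketch (illustrative only; `graph_def` is an
+// already-populated GraphDef):
+//
+//   GraphView view(&graph_def);
+//   const NodeDef* node = view.GetNode("my_node");
+//   for (const GraphView::InputPort& fanout :
+//        view.GetFanouts(*node, /*include_controlled_nodes=*/false)) {
+//     // fanout.node is a `const NodeDef*` here, so it cannot be mutated.
+//   }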
+class GraphView
+    : public internal::GraphViewInternal<const GraphDef, const NodeDef> {
+ public:
+  explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
+    for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
+    for (const NodeDef& node : graph->node()) AddFanouts(&node);
+  }
+};

 }  // end namespace grappler
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index f90e2c8cfcd..cbf859a4a99 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/

 #include "tensorflow/core/grappler/graph_view.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
 #include "tensorflow/cc/ops/parsing_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/grappler/grappler_item.h"
@@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) {

   const NodeDef* add_node = graph.GetNode("AddN");
   EXPECT_NE(nullptr, add_node);
-  string fanouts;
-  for (const auto& fo : graph.GetFanouts(*add_node, false)) {
-    strings::StrAppend(&fanouts,
-                       strings::StrCat(fo.node->name(), ":", fo.port_id, " "));
-  }
-  EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts);
-  string fanins;
-  for (const auto& fi : graph.GetFanins(*add_node, false)) {
-    strings::StrAppend(&fanins,
-                       strings::StrCat(fi.node->name(), ":", fi.port_id, " "));
+  absl::flat_hash_set<string> fanouts;
+  absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
+  for (const auto& fo : graph.GetFanouts(*add_node, false)) {
+    fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
   }
-  EXPECT_EQ("Square_1:0 Square:0 ", fanins);
+  EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
+  EXPECT_EQ(fanouts, expected_fanouts);
+
+  absl::flat_hash_set<string> fanins;
+  absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
+  for (const auto& fi : graph.GetFanins(*add_node, false)) {
+    fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
+  }
+  EXPECT_EQ(graph.NumFanins(*add_node, false), 2);
+  EXPECT_EQ(fanins, expected_fanins);
 }

 TEST_F(GraphViewTest, ControlDependencies) {
diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc
index f0aff90c6c2..67e804cbfbd 100644
--- a/tensorflow/core/grappler/mutable_graph_view.cc
+++ b/tensorflow/core/grappler/mutable_graph_view.cc
@@ -19,8 +19,26 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

+const absl::flat_hash_set<MutableGraphView::InputPort>&
+MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
+  return GetFanout(MutableGraphView::OutputPort(
+      const_cast<NodeDef*>(port.node), port.port_id));
+}
+
+absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
+    const GraphView::InputPort& port) const {
+  return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
+                                              port.port_id));
+}
+
+const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
+    const GraphView::InputPort& port) const {
+  return GetRegularFanin(MutableGraphView::InputPort(
+      const_cast<NodeDef*>(port.node), port.port_id));
+}
+
 NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
-  auto* node_in_graph = GetGraph()->add_node();
+  auto* node_in_graph = graph()->add_node();
   *node_in_graph = std::move(node);

   AddUniqueNodeOrDie(node_in_graph);
@@ -31,7 +49,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
 NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
                                       const int output_port_id) {
-  auto* node_in_graph = GetGraph()->add_node();
+  auto* node_in_graph = graph()->add_node();
   *node_in_graph = std::move(node);

   AddUniqueNodeOrDie(node_in_graph);
@@ -46,8 +64,7 @@ NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
 void MutableGraphView::ReplaceInput(const NodeDef& old_input,
                                     const NodeDef& new_input,
                                     const int output_port_id) {
-  GraphView::OutputPort output_port =
-      GetOutputPort(old_input.name(), output_port_id);
+  OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
   auto fanout = GetFanout(output_port);
   for (auto& input_port : fanout) {
     input_port.node->set_input(input_port.port_id, new_input.name());
@@ -57,17 +74,17 @@ void MutableGraphView::ReplaceInput(const NodeDef& old_input,

 void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
   for (const string& node_name_to_delete : nodes_to_delete)
-    RemoveFanouts(MutableNodes()->at(node_name_to_delete));
+    RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
   for (const string& node_name_to_delete : nodes_to_delete)
-    MutableNodes()->erase(node_name_to_delete);
-  EraseNodesFromGraph(nodes_to_delete, GetGraph());
+    mutable_nodes()->erase(node_name_to_delete);
+  EraseNodesFromGraph(nodes_to_delete, graph());
 }

 void MutableGraphView::RemoveFanouts(NodeDef* node) {
   for (int i = 0; i < node->input_size(); ++i) {
     OutputPort fanin;
     string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
-    fanin.node = (*MutableNodes())[fanin_name];
+    fanin.node = (*mutable_nodes())[fanin_name];

     InputPort input;
     input.node = node;
@@ -76,7 +93,7 @@ void MutableGraphView::RemoveFanouts(NodeDef* node) {
     else
       input.port_id = i;

-    (*MutableFanouts())[fanin].erase(input);
+    (*mutable_fanouts())[fanin].erase(input);
   }
 }

diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h
index 971e5503d4c..702751a57fd 100644
--- a/tensorflow/core/grappler/mutable_graph_view.h
+++ b/tensorflow/core/grappler/mutable_graph_view.h
@@ -24,11 +24,25 @@ namespace grappler {
 // A utility class to simplify the traversal of a GraphDef that, unlike
 // GraphView, supports updating the graph. Note that you should not modify the
 // graph separately, because the view will get out of sync.
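+//
+// For example (an illustrative sketch, not part of this change):
+//
+//   MutableGraphView view(&graph_def);
+//   NodeDef node;
+//   node.set_name("new_node");
+//   view.AddNode(std::move(node));  // mutates the graph AND updates the view
+//
+// whereas calling graph_def.add_node() directly would leave the view stale.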
-class MutableGraphView : public GraphView {
- public:
-  using GraphView::GraphView;
-  GraphDef* GetGraph() { return MutableGraph(); }
+class MutableGraphView
+    : public internal::GraphViewInternal<GraphDef, NodeDef> {
+ public:
+  explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
+    for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
+    for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node);
+  }
+
+  // Lookup fanouts/fanins using immutable ports.
+  using GraphViewInternal::GetFanout;
+  const absl::flat_hash_set<InputPort>& GetFanout(
+      const GraphView::OutputPort& port) const;
+
+  using GraphViewInternal::GetFanin;
+  absl::flat_hash_set<OutputPort> GetFanin(
+      const GraphView::InputPort& port) const;
+
+  using GraphViewInternal::GetRegularFanin;
+  const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;

   // Adds a new node to graph and updates the view.
   NodeDef* AddNode(NodeDef&& node);
diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc
index 2536bec35dd..7d9025e031e 100644
--- a/tensorflow/core/grappler/mutable_graph_view_test.cc
+++ b/tensorflow/core/grappler/mutable_graph_view_test.cc
@@ -26,7 +26,8 @@ namespace {
 bool FindChildWithName(const MutableGraphView& graph,
                        const string& output_port_name,
                        const string& input_name) {
-  GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
+  MutableGraphView::OutputPort output_port =
+      graph.GetOutputPort(output_port_name, 0);
   auto fanout = graph.GetFanout(output_port);
   for (auto& input_port : fanout) {
     if (input_port.node->name() == input_name) return true;
@@ -59,10 +60,10 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
   GraphDef new_graph = item.graph;
   MutableGraphView graph(&new_graph);

-  GraphView::InputPort input = graph.GetInputPort("AddN", 0);
+  MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
   EXPECT_EQ("AddN", input.node->name());
   EXPECT_EQ(0, input.port_id);
-  GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+  MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
   EXPECT_EQ("Square", fanin.node->name());
   EXPECT_EQ(0, fanin.port_id);

@@ -89,7 +90,7 @@ TEST(MutableGraphViewTest, InsertNodes) {
   GraphDef new_graph = item.graph;
   MutableGraphView graph(&new_graph);

-  GraphView::InputPort input = graph.GetInputPort("AddN", 0);
+  MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);

   NodeDef new_node = *input.node;
   new_node.set_name("new_node");
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 127c1603ba4..0637e3b2e15 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -145,8 +145,8 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:mutable_graph_view",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/utils:functions",
@@ -422,8 +422,8 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:mutable_graph_view",
        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
@@ -625,12 +625,13 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:frame", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -663,8 +664,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index 1ad495bbad0..3ffbfba95ee 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(), &fused_node); fused_node.set_op("FilterDataset"); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index b863a25dc5f..90208c1fba6 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper( MutableGraphView* graph) { NodeDef node; node.set_op(kConstOpName); - SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node); + SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node); (*node.mutable_attr())["dtype"].set_type(dtype); std::unique_ptr tensor = @@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper( NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { NodeDef node; node.set_op("Placeholder"); - SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); + SetUniqueGraphNodeName(node.op(), graph->graph(), &node); (*node.mutable_attr())["dtype"].set_type(dtype); TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); shape->set_unknown_rank(false); @@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op, if (!name.empty()) { node.set_name(string(name)); } else { - SetUniqueGraphNodeName(op, graph->GetGraph(), &node); + SetUniqueGraphNodeName(op, graph->graph(), &node); } node.set_op(string(op)); for (const string& input : inputs) { @@ -228,7 +228,7 @@ std::vector FindAllGraphNodesWithOp(const string& op, NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { if (node.input_size() == 0) return nullptr; - GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); + MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); return graph.GetRegularFanin(input_port).node; } diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 4ab6d71532c..5c0f03dca87 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ 
b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* bool_node = AddScalarConstNode(true, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph())); EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); } @@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* double_node = AddScalarConstNode(3.14, &graph); - EXPECT_TRUE( - ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph())); EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); } @@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* float_node = AddScalarConstNode(3.14, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph())); EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); } @@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* int_node = AddScalarConstNode(42, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph())); EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); } @@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* int64_node = AddScalarConstNode(42, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph())); EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); } @@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) { GraphDef graph_def; MutableGraphView graph(&graph_def); NodeDef* string_node = AddScalarConstNode("hello", &graph); - EXPECT_TRUE( - ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph())); EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); } @@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) { TEST(GraphUtilsTest, ContainsGraphNodeWithName) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph())); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph())); graph.DeleteNodes({"A"}); - EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); + EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph())); } TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { @@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { TEST(GraphUtilsTest, ContainsNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph())); AddNode("A", "OpA", {}, {}, &graph); - 
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph())); graph.DeleteNodes({"A"}); - EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); + EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph())); } TEST(GraphUtilsTest, FindGraphNodeWithName) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1); graph.DeleteNodes({"A"}); - EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1); } TEST(GraphUtilsTest, FindGraphFunctionWithName) { @@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) { TEST(GraphUtilsTest, FindGraphNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0); graph.DeleteNodes({"B"}); - EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1); - EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1); + EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1); } TEST(GraphUtilsTest, FindAllGraphNodesWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); - EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1); AddNode("A", "OpA", {}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph); std::vector result_indices = - FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + FindAllGraphNodesWithOp("OpA", *graph.graph()); EXPECT_EQ(result_indices.size(), 2); EXPECT_EQ(result_indices.at(0), 0); EXPECT_EQ(result_indices.at(1), 2); graph.DeleteNodes({"A2"}); std::vector result_indices_new = - FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + FindAllGraphNodesWithOp("OpA", *graph.graph()); EXPECT_EQ(result_indices_new.size(), 1); EXPECT_EQ(result_indices_new.at(0), 0); } diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc index ce0b2db0396..91b3f71e9e3 100644 --- a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node, const FunctionDef& stateless_function, MutableGraphView* graph) { NodeDef stateless_map; - graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(), &stateless_map); stateless_map.set_op("MapDataset"); @@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, MutableGraphView* graph) { NodeDef random_dataset; random_dataset.set_op("RandomDataset"); - graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(), &random_dataset); const 
auto* seed = graph_utils::AddScalarConstNode( @@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { NodeDef batch_dataset; batch_dataset.set_op("BatchDatasetV2"); - graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(), &batch_dataset); const auto* batch_size = graph_utils::AddScalarConstNode(2, graph); const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph); @@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, MutableGraphView* graph) { NodeDef zip_node; - graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(), &zip_node); zip_node.set_op("ZipDataset"); diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc index 9e382aeef9c..6a5a70e084e 100644 --- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc @@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kInsertOpName); graph_utils::SetUniqueGraphNodeName( - strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(), - &new_node); + strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node); // Set the input of LatencyDataset node as `node` new_node.add_input(node.name()); @@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item, // node corresponds to a `Dataset` op. continue; } - GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0); + MutableGraphView::OutputPort output_port = + graph.GetOutputPort(node.name(), 0); auto fanout = graph.GetFanout(output_port); if (fanout.size() > 1) { LOG(WARNING) << node.name() << " has fanout size " << fanout.size(); diff --git a/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc b/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc index f9d7d027c12..bab2c361494 100644 --- a/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc +++ b/tensorflow/core/grappler/optimizers/data/make_numa_aware.cc @@ -29,7 +29,7 @@ namespace { NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) { NodeDef numa_aware_node = node; - graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(), &numa_aware_node); numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset"); return numa_aware_node; diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index e66766eb23b..2807e0886bb 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, MutableGraphView* graph) { NodeDef new_node; new_node.set_op(kFusedOpName); - graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(), - &new_node); + graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node); // Set the `input` input argument. 
new_node.add_input(map_node.input(0)); diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index b676246b318..eed558de7eb 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); + EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index c4868eacbbf..7cb52c36b2d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -37,8 +37,7 @@ NodeDef MakeFusedNode(const NodeDef& map_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), - &fused_node); + graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node); fused_node.set_op("MapDataset"); fused_node.add_input(map_node.input(0)); @@ -72,8 +71,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node, const NodeDef& filter_node, MutableGraphView* graph) { NodeDef filter_by_component; - graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", - graph->GetGraph(), &filter_by_component); + graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(), + &filter_by_component); filter_by_component.set_op("FilterByLastComponentDataset"); filter_by_component.add_input(fused_map_node.name()); diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index bd943342e80..23bb49db62b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, const FunctionDef& fused_function, MutableGraphView* graph) { NodeDef fused_node; - graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), - &fused_node); + graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node); fused_node.set_op("MapDataset"); fused_node.add_input(parent_map_node.input(0)); diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index 782c9f48b74..f4c86174571 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function, NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { NodeDef parallel_map = map_node; - graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(), ¶llel_map); parallel_map.set_op("ParallelMapDataset"); // TODO(b/114475558): We want to set `num_parallel_calls` to a special value, diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 0576d075c25..04ab46885cf 100644 --- 
a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node, MutableGraphView* graph) { NodeDef batch_node; batch_node.set_op(old_batch_node.op()); - graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(), + graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(), &batch_node); // Set the `input_dataset` input argument @@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node, MutableGraphView* graph) { NodeDef map_node; map_node.set_op(old_map_node.op()); - graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(), - &map_node); + graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node); // Set the `input_dataset` input argument map_node.add_input(new_batch_node.name()); diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index e47e91a375b..763434b6136 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -30,7 +30,7 @@ namespace tensorflow { namespace grappler { namespace { -bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { +bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) { if (take_node.op() != "TakeDataset") return false; const auto& count_node = *graph.GetNode(take_node.input(1)); @@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) { return node.attr().at("value").tensor().int64_val(0) == value; } -bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { +bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) { if (skip_node.op() != "SkipDataset") return false; // We are looking only for skip(0) nodes. return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); } -bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { +bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) { if (repeat_node.op() != "RepeatDataset") return false; // We are looking only for repeat(1) nodes. return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); } -bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) { +bool IsPrefetchZero(const NodeDef& prefetch_node, + const MutableGraphView& graph) { if (prefetch_node.op() != "PrefetchDataset") return false; // We are looking only for prefetch(0) nodes. 
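    // (A prefetch with buffer_size == 0 leaves the stream of elements
    // unchanged, which is what makes the node safe to eliminate.)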
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0); } -bool IsNoOp(const NodeDef& node, const GraphView& graph) { +bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) { return IsTakeAll(node, graph) || IsSkipNone(node, graph) || IsRepeatOne(node, graph) || IsPrefetchZero(node, graph); } diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc index f0696eb76d0..556e1d3ab57 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc @@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); - EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); + EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index cc04ed3340b..f99826ddcad 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -31,8 +31,8 @@ limitations under the License. #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" @@ -219,8 +219,7 @@ class FunctionOptimizerContext { : grappler_item_id_(item.id), graph_version_(item.graph.versions().producer()), function_library_(OpRegistry::Global(), item.graph.library()), - // GraphView doesn't not modify the graph or the nodes. - graph_view_(const_cast(&item.graph)) { + graph_view_(&item.graph) { InitializeTrulyConstNodes(item); InitializeInlinedFunctions(opt_level, item); InitializeFetchNodes(item); @@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Function specialization might change the number of function outputs, so we // have to process the final optimized graph and update all the node mapping. if (ctx.RequiresOutputMapping()) { - GraphView optimized_graph_view(optimized_graph); + MutableGraphView optimized_graph_view(optimized_graph); for (const auto& output_mapping : ctx.output_mappings()) { const auto& node_name = output_mapping.first; const auto& mappings = output_mapping.second; @@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, int to = mapping.second; // Get the output port corresponding to the old output position. - GraphView::OutputPort from_port = + MutableGraphView::OutputPort from_port = optimized_graph_view.GetOutputPort(node_name, from); // Update all input ports that read from old output port. 
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index c74a9409494..775fb9a95f2 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -29,8 +30,8 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
@@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
   return Status::OK();
 }
 
-Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
-                          const NodeMap& node_map, DeviceBase* cpu_device,
-                          ResourceMgr* resource_mgr, bool* has_dead_fanout,
-                          int* dead_fanout) {
+Status CheckForDeadFanout(const MutableGraphView& view,
+                          const NodeDef& switch_node, const NodeMap& node_map,
+                          DeviceBase* cpu_device, ResourceMgr* resource_mgr,
+                          bool* has_dead_fanout, int* dead_fanout) {
   *has_dead_fanout = false;
   GraphView::InputPort switch_loopcond_port(&switch_node, 1);
-  NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node;
+  const NodeDef* switch_predicate =
+      view.GetRegularFanin(switch_loopcond_port).node;
 
   // CASE 1: Control is a constant.
   if (IsConstant(*switch_predicate)) {
@@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
   }
 
   GraphView::InputPort switch_input_port(&switch_node, 0);
-  NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
+  const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
 
   // CASE 2: Zero-iteration while loop.
   // We check if it's a while loop such that the condition is a simple binary
@@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches(
   std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
   // TODO(bsteiner): also rewrite switches as identity. For now we just record
   // them
-  std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
-      identity_switches;
+  absl::flat_hash_set<GraphView::OutputPort> identity_switches;
 
-  GraphView view(optimized_graph);
+  MutableGraphView view(optimized_graph);
   for (const NodeDef& node : optimized_graph->node()) {
     if (!IsSwitch(node)) {
       continue;
@@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches(
     if (!has_dead_fanout) {
       continue;
    }
-    GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
+    GraphView::OutputPort dead(&node, dead_fanout);
     identity_switches.insert(dead);
 
-    SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
-    for (const GraphView::InputPort& port : view.GetFanout(dead)) {
+    SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::InputPort>>
+        zombie_inputs;
+    for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
      if (dead_nodes.find(port.node) == dead_nodes.end()) {
         zombie_inputs.PushBack(port);
       }
@@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches(
         dead_merge_inputs;
     bool found_node_to_preserve = false;
     while (!found_node_to_preserve && !zombie_inputs.Empty()) {
-      GraphView::InputPort dead = zombie_inputs.PopBack();
+      MutableGraphView::InputPort dead = zombie_inputs.PopBack();
       if (nodes_to_preserve.find(dead.node->name()) !=
           nodes_to_preserve.end()) {
         found_node_to_preserve = true;
@@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches(
           found_node_to_preserve = true;
           break;
         }
-        GraphView::OutputPort value_index(dead.node, 1);
-        const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
-            index_fanout = view.GetFanout(value_index);
+        MutableGraphView::OutputPort value_index(dead.node, 1);
+        const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
+            view.GetFanout(value_index);
         if (!index_fanout.empty()) {
           // The 2nd output (that indicates which input is propagated) is
           // connected. This never happens in practice, so we'll just skip this
@@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches(
       }
       if (fully_dead) {
         local_dead_nodes.insert(dead.node);
-        for (const GraphView::InputPort& port :
+        for (const MutableGraphView::InputPort& port :
              view.GetFanouts(*dead.node, true)) {
           zombie_inputs.PushBack(port);
         }
@@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches(
         break;
       } else {
         if (local_dead_nodes.insert(dead.node).second) {
-          for (const GraphView::InputPort& dead_fanout :
+          for (const MutableGraphView::InputPort& dead_fanout :
                view.GetFanouts(*dead.node, true)) {
             zombie_inputs.PushBack(dead_fanout);
           }
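
Note: `identity_switches` above now lives in an `absl::flat_hash_set`, which
hashes through `absl::Hash` (the BUILD change adds the `absl/hash` dep). A
sketch of how a port-like struct can opt into that hashing (illustrative
only; the real port types live in graph_view.h):

    #include <utility>

    #include "absl/container/flat_hash_set.h"

    struct NodeDef;  // ports hash by node pointer identity plus port id

    struct OutputPort {
      const NodeDef* node = nullptr;
      int port_id = -1;

      bool operator==(const OutputPort& other) const {
        return node == other.node && port_id == other.port_id;
      }

      // Opts the type into absl hashing, and thus absl::flat_hash_set.
      template <typename H>
      friend H AbslHashValue(H h, const OutputPort& port) {
        return H::combine(std::move(h), port.node, port.port_id);
      }
    };

    int main() {
      absl::flat_hash_set<OutputPort> identity_switches;
      OutputPort dead;
      identity_switches.insert(dead);
      return identity_switches.size() == 1 ? 0 : 1;
    }
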
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index c36dc65bb04..e0a913565fc 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -30,8 +30,8 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_memory.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/costs/utils.h"
-#include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
@@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
 bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
   // Look for AddN nodes (and equivalent) and record input names.
-  GraphView view(&item->graph);
+  MutableGraphView view(&item->graph);
 
   std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
   for (NodeDef& node : *item->graph.mutable_node()) {
@@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
     for (int i = 0; i < node->input_size(); ++i) {
       const string& input = node->input(i);
       const string node_name = NodeName(input);
-      NodeDef* node = view.GetNode(node_name);
+      const NodeDef* node = view.GetNode(node_name);
       input_topo_index.push_back(topo_order.at(node));
     }
     int min_input_topo_index = INT_MAX;
@@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger(
   return nullptr;
 }
 
-static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
+static bool IsSwappable(const MutableGraphView& graph,
+                        MutableGraphView::OutputPort output) {
   const NodeDef& node = *output.node;
   // There is no point in swapping out persistent tensors, since the tensor will
   // continue to use memory.
@@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
     // If placed on the same device, these nodes are just forwarding references
     // to their input. Therefore they are swappable iff their fanin is swappable
     // or it resides on a different device.
-    GraphView::InputPort input;
+    MutableGraphView::InputPort input;
     input.node = output.node;
     input.port_id = 0;
-    GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+    MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
     if (fanin.node->device() == node.device()) {
       return IsSwappable(graph, fanin);
     }
@@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
 }
 
 static NodeDef* FindSwapOutTrigger(
-    const NodeDef* node, int input_id, const GraphView& view,
+    const NodeDef* node, int input_id, const MutableGraphView& view,
     const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
         execution_times) {
   // Find the output port that generated the tensor to swap.
-  GraphView::InputPort swap;
+  MutableGraphView::InputPort swap;
   swap.node = const_cast<NodeDef*>(node);
   swap.port_id = input_id;
-  GraphView::OutputPort generator = view.GetRegularFanin(swap);
+  MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
   if (!generator.node) {
     return nullptr;
   }
 
-  const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
+  const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
       view.GetFanout(generator);
   NodeDef* trigger = nullptr;
   Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
@@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger(
   return trigger;
 }
 
-static bool IsSwappable(GraphView::InputPort input) {
+static bool IsSwappable(MutableGraphView::InputPort input) {
   const NodeDef& node = *input.node;
 
   const OpDef* op_def;
@@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) {
 }
 
 struct MemInfo {
-  GraphView::OutputPort port;
+  MutableGraphView::OutputPort port;
   int64 memory_used;
-  std::vector<GraphView::InputPort> uses_left;
+  std::vector<MutableGraphView::InputPort> uses_left;
   double fitness;
 
   bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
@@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates(
 
     std::vector<MemInfo> mem_state;
 
-    GraphView graph(&item->graph);
+    MutableGraphView graph(&item->graph);
     for (const auto& live_tensor : mem_usage.live_tensors) {
       if (live_tensor.memory_used <= 1024) {
         // Don't bother with small tensors.
@@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates(
       if (skip_list->find(live_tensor.node) != skip_list->end()) {
         continue;
       }
-      GraphView::OutputPort port =
+      MutableGraphView::OutputPort port =
           graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
       if (!IsSwappable(graph, port)) {
         continue;
@@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates(
       Costs::Duration allocation_time = live_tensor.allocation_time;
       Costs::Duration earliest_use(Costs::Duration::infinity());
       bool valid = true;
-      for (GraphView::InputPort input : graph.GetFanout(port)) {
+      for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
         // Get execution time.
         auto it = op_completion_times.find(input.node->name());
         if (it == op_completion_times.end()) {
@@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates(
         // the values do not fit into any integral type.
         mem_info.fitness =
             MathUtil::IPow((earliest_use - peak_time).count(), 2) /
-            MathUtil::IPow(mem_info.uses_left.size(), 2) +
+                MathUtil::IPow(mem_info.uses_left.size(), 2) +
             MathUtil::IPow((allocation_time - peak_time).count(), 2);
         mem_info.fitness = -mem_info.fitness;
         mem_state.push_back(mem_info);
@@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates(
     std::sort(mem_state.begin(), mem_state.end());
 
     for (const MemInfo& mem_info : mem_state) {
-      for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
+      for (const MutableGraphView::InputPort fanout_to_swap :
+           mem_info.uses_left) {
        VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
                 << fanout_to_swap.port_id << " of tensor "
                 << mem_info.port.node->name() << ":" << mem_info.port.port_id
@@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
   for (const auto& node : item->graph.node()) {
     name_map[node.name()] = &node;
   }
-  GraphView view(&item->graph);
+  MutableGraphView view(&item->graph);
 
   bool updated_graph = false;
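
Note: as the comment in the fitness hunk says, the expression is computed in
doubles because the squared time deltas do not fit any integral type. The
same formula restated with `std::pow` standing in for `MathUtil::IPow`
(illustrative only, not the TF helper):

    #include <cmath>

    // fitness = -((earliest_use - peak)^2 / uses_left^2 + (alloc - peak)^2).
    // Since std::sort is ascending and fitness is negated, tensors allocated
    // and used far from the memory peak, with few remaining uses, sort first.
    double SwapFitness(double earliest_use, double allocation_time,
                       double peak_time, double uses_left) {
      return -(std::pow(earliest_use - peak_time, 2) / std::pow(uses_left, 2) +
               std::pow(allocation_time - peak_time, 2));
    }

    int main() { return SwapFitness(10.0, 2.0, 5.0, 1.0) < 0 ? 0 : 1; }
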
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 6ccb1cd783d..7dae0e3cd9e 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -18,8 +18,8 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
@@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   GraphProperties properties(item);
   bool inferred_properties = false;
 
-  GraphView graph(optimized_graph);
+  MutableGraphView graph(optimized_graph);
 
   // The product of all the dimensions in a tensor shape can be expressed more
   // simply as the size of the tensor.
@@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     if (!IsShape(node)) {
       continue;
     }
-    for (GraphView::InputPort fanout :
-         graph.GetFanout(GraphView::OutputPort(&node, 0))) {
+    for (MutableGraphView::InputPort fanout :
+         graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
       if (fanout.node->op() != "Prod") {
         continue;
       }
@@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         // rewrite the whole expression directly as a Size operation.
         continue;
       }
-      const GraphView::OutputPort reduce_indices =
-          graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+      const MutableGraphView::OutputPort reduce_indices =
+          graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
       if (!inferred_properties) {
         // Infer properties lazily in case they are not needed.
         TF_RETURN_IF_ERROR(properties.InferStatically(false));
@@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     // is possible whenever the symbolic dimensions in the numerator and
     // denominator cancel each other.
     if (node.op() == "Div") {
-      const GraphView::OutputPort input1 =
-          graph.GetRegularFanin(GraphView::InputPort(&node, 0));
-      const GraphView::OutputPort input2 =
-          graph.GetRegularFanin(GraphView::InputPort(&node, 1));
+      const MutableGraphView::OutputPort input1 =
+          graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
+      const MutableGraphView::OutputPort input2 =
+          graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
       if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
         continue;
       }
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 2efefe44149..dbe425b75fd 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -101,6 +101,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:mutable_graph_view",
        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )
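
Note: the first shape_optimizer rewrite rests on the identity that the
product of a tensor's dimensions equals its element count, so Prod(Shape(x))
over all axes can become Size(x). A plain C++ illustration of the identity:

    #include <cstdint>
    #include <vector>

    int main() {
      const std::vector<int64_t> shape = {2, 3, 4};  // a 2x3x4 tensor
      int64_t prod = 1;
      for (int64_t dim : shape) prod *= dim;
      // 24 elements in total: Prod(Shape(x)) == Size(x).
      return prod == 24 ? 0 : 1;
    }
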
diff --git a/tensorflow/core/grappler/utils/traversal.cc b/tensorflow/core/grappler/utils/traversal.cc
index e5b2d17ae55..69522775686 100644
--- a/tensorflow/core/grappler/utils/traversal.cc
+++ b/tensorflow/core/grappler/utils/traversal.cc
@@ -21,8 +21,11 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-void ReverseDfs(
-    const GraphView& graph_view, const std::vector<const NodeDef*>& from,
+namespace {
+
+template <typename GraphViewType>
+void ReverseDfsInternal(
+    const GraphViewType& graph_view, const std::vector<const NodeDef*>& from,
     const std::function<void(const NodeDef*)>& pre_order,
     const std::function<void(const NodeDef*)>& post_order,
     const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
@@ -79,5 +82,25 @@ void ReverseDfs(
   }
 }
 
+}  // namespace
+
+void ReverseDfs(
+    const GraphView& graph_view, const std::vector<const NodeDef*>& from,
+    const std::function<void(const NodeDef*)>& pre_order,
+    const std::function<void(const NodeDef*)>& post_order,
+    const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
+  ReverseDfsInternal<GraphView>(graph_view, from, pre_order, post_order,
+                                on_back_edge);
+}
+
+void ReverseDfs(
+    const MutableGraphView& graph_view,
+    const std::vector<const NodeDef*>& from,
+    const std::function<void(const NodeDef*)>& pre_order,
+    const std::function<void(const NodeDef*)>& post_order,
+    const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
+  ReverseDfsInternal<MutableGraphView>(graph_view, from, pre_order, post_order,
+                                       on_back_edge);
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/traversal.h b/tensorflow/core/grappler/utils/traversal.h
index 8aa97237cc2..5b7737f97eb 100644
--- a/tensorflow/core/grappler/utils/traversal.h
+++ b/tensorflow/core/grappler/utils/traversal.h
@@ -18,6 +18,7 @@ limitations under the License.
 #include <functional>
 
 #include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -34,6 +35,12 @@ void ReverseDfs(
     const std::function<void(const NodeDef*)>& post_order,
     const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
 
+void ReverseDfs(
+    const MutableGraphView& graph_view,
+    const std::vector<const NodeDef*>& from,
+    const std::function<void(const NodeDef*)>& pre_order,
+    const std::function<void(const NodeDef*)>& post_order,
+    const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
+
 }  // namespace grappler
 }  // namespace tensorflow
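
Note: the traversal change keeps a single templated implementation and
dispatches to it from two thin non-template overloads, one per view type.
The shape of that pattern, reduced to a sketch with hypothetical types:

    #include <iostream>

    struct ConstView {};    // stands in for GraphView
    struct MutableView {};  // stands in for MutableGraphView

    namespace {

    // One shared implementation; only the view type varies.
    template <typename GraphViewType>
    void ReverseDfsInternal(const GraphViewType& view) {
      std::cout << "traversing\n";
    }

    }  // namespace

    void ReverseDfs(const ConstView& view) { ReverseDfsInternal(view); }
    void ReverseDfs(const MutableView& view) { ReverseDfsInternal(view); }

    int main() {
      ReverseDfs(ConstView{});
      ReverseDfs(MutableView{});
    }
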
diff --git a/tensorflow/core/grappler/utils/traversal_test.cc b/tensorflow/core/grappler/utils/traversal_test.cc
index fad26b5a9e3..c040477a089 100644
--- a/tensorflow/core/grappler/utils/traversal_test.cc
+++ b/tensorflow/core/grappler/utils/traversal_test.cc
@@ -14,9 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/utils/traversal.h"
-//#include "tensorflow/core/framework/node_def.pb.h"
-//#include "tensorflow/core/lib/core/status_test_util.h"
-//#include "tensorflow/core/platform/protobuf.h"
+
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/test.h"
 
@@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) {
         found_back_edge = true;
       });
 
-  EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order);
-  EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order);
+  // Pre/Post order traversals are non deterministic because a node fanin is an
+  // absl::flat_hash_set with non deterministic traversal order.
+  using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
+
+  std::set<ValidTraversal> valid_traversals = {
+      //            pre_order                       post_order
+      {{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}},
+      {{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}};
+
+  EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
   EXPECT_FALSE(found_back_edge);
 }
 
@@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) {
         back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
       });
 
-  EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order);
-  EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order);
+  // Pre/Post order traversals are non deterministic because a node fanin is an
+  // absl::flat_hash_set with non deterministic traversal order.
+  using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
+
+  std::set<ValidTraversal> valid_traversals = {
+      //            pre_order                       post_order
+      {{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}},
+      {{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}},
+      {{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}};
+
+  EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
   EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
 }
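
Note on the test idiom above: when a container's iteration order is
unspecified, as with absl::flat_hash_set, the test asserts membership in the
set of acceptable orders rather than matching one exact order. Reduced to a
standalone sketch with made-up node names:

    #include <set>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      using Order = std::vector<std::string>;
      using ValidTraversal = std::pair<Order, Order>;  // {pre, post}

      // Every traversal order the graph legitimately allows.
      const std::set<ValidTraversal> valid = {
          {{"a", "b", "c"}, {"c", "b", "a"}},
          {{"a", "c", "b"}, {"b", "c", "a"}},
      };

      const Order pre = {"a", "b", "c"}, post = {"c", "b", "a"};
      return valid.count({pre, post}) == 1 ? 0 : 1;  // pass iff order is valid
    }
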