[Grappler] Preserve constness of the GraphDef in GraphView.

1. Split GrapView into GraphView and MutableGraphView with separate {Input/Output}Port types with different node pointer constness.

2. Properly use GraphView and MutableGraphView in graph properties, and get rid of const_cast.

3. Remove const_cast in function optimizer.

4. Migrate GraphView to absl containers and hash

PiperOrigin-RevId: 219488040
This commit is contained in:
Eugene Zhulenev 2018-10-31 09:42:35 -07:00 committed by TensorFlower Gardener
parent 92e604060a
commit 3eeaf9f1e1
32 changed files with 503 additions and 445 deletions

View File

@ -69,6 +69,9 @@ cc_library(
":utils", ":utils",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
], ],
) )
@ -82,6 +85,8 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -44,7 +44,7 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core:core_cpu_base", "//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/functions.h"
@ -456,10 +456,10 @@ class SymbolicShapeRefiner {
const GraphView& graph, const GraphView& graph,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports) const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
: graph_(graph), : graph_(graph),
function_library_(OpRegistry::Global(), graph.GetGraph()->library()), function_library_(OpRegistry::Global(), graph.graph()->library()),
fed_ports_(fed_ports) { fed_ports_(fed_ports) {
graph_def_version_ = graph.GetGraph()->versions().producer(); graph_def_version_ = graph.graph()->versions().producer();
node_to_context_.reserve(graph.GetGraph()->node_size()); node_to_context_.reserve(graph.graph()->node_size());
} }
const GraphView& graph() const { return graph_; } const GraphView& graph() const { return graph_; }
@ -512,7 +512,7 @@ class SymbolicShapeRefiner {
// Placeholder with Const) don't affect one in // Placeholder with Const) don't affect one in
// fun_to_grappler_function_item_. // fun_to_grappler_function_item_.
GrapplerFunctionItem grappler_function_item = it->second; GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph); MutableGraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes. // Forward shapes from function input nodes to argument nodes.
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) { for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
@ -532,7 +532,7 @@ class SymbolicShapeRefiner {
"Function inputs should not contain control nodes."); "Function inputs should not contain control nodes.");
} }
NodeDef* input_node = graph_.GetNode(node_name); const NodeDef* input_node = graph_.GetNode(node_name);
if (input_node == nullptr) { if (input_node == nullptr) {
return errors::FailedPrecondition(node_name, return errors::FailedPrecondition(node_name,
" was not found in the graph."); " was not found in the graph.");
@ -566,7 +566,7 @@ class SymbolicShapeRefiner {
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i); const string& input = function_node->input(i);
const string& node_name = NodeName(input); const string& node_name = NodeName(input);
NodeDef* input_node = graph_.GetNode(node_name); const NodeDef* input_node = graph_.GetNode(node_name);
if (IsConstant(*input_node)) { if (IsConstant(*input_node)) {
TF_CHECK_OK( TF_CHECK_OK(
ReplaceInputWithConst(*input_node, i, &grappler_function_item)); ReplaceInputWithConst(*input_node, i, &grappler_function_item));
@ -1441,8 +1441,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
continue; continue;
} }
ShapeHandle input = in->output(fanin.src.port_id); ShapeHandle input = in->output(fanin.src.port_id);
CHECK_EQ(fanin.tgt.node, node); CHECK_EQ(fanin.dst.node, node);
c->SetInput(fanin.tgt.port_id, input); c->SetInput(fanin.dst.port_id, input);
if (!out_initialized) { if (!out_initialized) {
out_initialized = true; out_initialized = true;
out = input; out = input;
@ -1673,7 +1673,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
} }
} }
GraphView graph_view(const_cast<GraphDef*>(&item_.graph)); GraphView graph_view(&item_.graph);
// List the resources and the nodes using them. Also collect the Merge nodes, // List the resources and the nodes using them. Also collect the Merge nodes,
// fed nodes, and primary inputs. // fed nodes, and primary inputs.
@ -1725,10 +1725,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
for (const auto& resource : resources) { for (const auto& resource : resources) {
for (const NodeDef* src : resource.second.first) { for (const NodeDef* src : resource.second.first) {
resource_handles[src] = resource.first; resource_handles[src] = resource.first;
for (const NodeDef* tgt : resource.second.second) { for (const NodeDef* dst : resource.second.second) {
// Add control edges from enqueue to dequeue nodes to ensure they are // Add control edges from enqueue to dequeue nodes to ensure they are
// processed in their logical order. // processed in their logical order.
extra_deps.emplace_back(src, tgt); extra_deps.emplace_back(src, dst);
} }
} }
} }

View File

@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
return OpPortIdToArgId(node, op.input_arg(), port_id); return OpPortIdToArgId(node, op.input_arg(), port_id);
} }
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
AddUniqueNodeOrDie(node);
}
for (NodeDef& node : *graph_->mutable_node()) {
AddFanouts(&node);
}
}
void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
auto result = nodes_.emplace(node->name(), node);
// Check that the graph doesn't contain multiple nodes with the same name.
CHECK(result.second) << "Non unique node name detected: " << node->name();
}
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
input.node = node;
if (fanin.port_id < 0) {
input.port_id = -1;
} else {
input.port_id = i;
num_regular_outputs_[fanin.node] =
std::max(num_regular_outputs_[fanin.node], fanin.port_id);
}
fanouts_[fanin].insert(input);
}
}
NodeDef* GraphView::GetNode(const string& node_name) const {
auto it = nodes_.find(node_name);
if (it == nodes_.end()) {
return nullptr;
}
return it->second;
}
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
int port_id) const {
InputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id input ports
result.port_id = port_id;
return result;
}
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
int port_id) const {
OutputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id output ports
result.port_id = port_id;
return result;
}
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
GraphView::GetFanout(const GraphView::OutputPort& port) const {
auto it = fanouts_.find(port);
if (it == fanouts_.end()) {
return empty_set_;
}
return it->second;
}
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
GraphView::GetFanin(const GraphView::InputPort& port) const {
std::unordered_set<GraphView::OutputPort, GraphView::HashPort> result;
if (port.port_id >= 0) {
result.insert(GetRegularFanin(port));
} else {
for (int i = port.node->input_size() - 1; i >= 0; --i) {
OutputPort fanin;
string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id);
if (fanin.port_id < 0) {
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.node = it->second;
result.insert(fanin);
}
} else {
break;
}
}
}
return result;
}
const GraphView::OutputPort GraphView::GetRegularFanin(
const GraphView::InputPort& port) const {
CHECK_LE(0, port.port_id);
OutputPort fanin;
string fanin_name =
ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
auto it = nodes_.find(fanin_name);
if (it == nodes_.end()) {
fanin.node = nullptr;
} else {
fanin.node = it->second;
}
return fanin;
}
std::unordered_set<GraphView::InputPort, GraphView::HashPort>
GraphView::GetFanouts(const NodeDef& node,
bool include_controlled_nodes) const {
std::unordered_set<InputPort, HashPort> result;
OutputPort port;
port.node = const_cast<NodeDef*>(&node);
const int first_port_id = include_controlled_nodes ? -1 : 0;
auto it = num_regular_outputs_.find(&node);
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
result.insert(it->second.begin(), it->second.end());
}
}
return result;
}
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
GraphView::GetFanins(const NodeDef& node,
bool include_controlling_nodes) const {
std::unordered_set<OutputPort, HashPort> result;
for (int i = 0; i < node.input_size(); ++i) {
OutputPort fanin;
string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
if (fanin.port_id < 0) {
if (!include_controlling_nodes) {
break;
}
}
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.node = it->second;
result.insert(fanin);
}
}
return result;
}
int GraphView::NumFanins(const NodeDef& node,
bool include_controlling_nodes) const {
int count = 0;
for (const string& input : node.input()) {
if (!include_controlling_nodes && IsControlInput(input)) {
break;
}
count += 1;
}
return count;
}
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
GraphView::GetFanoutEdges(const NodeDef& node,
bool include_controlled_edges) const {
std::unordered_set<Edge, HashEdge> result;
OutputPort port;
port.node = const_cast<NodeDef*>(&node);
const int first_port_id = include_controlled_edges ? -1 : 0;
auto it = num_regular_outputs_.find(&node);
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
Edge fanout;
fanout.src.node = const_cast<NodeDef*>(&node);
fanout.src.port_id = i;
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
fanout.tgt = *itr;
result.insert(fanout);
}
}
}
return result;
}
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
GraphView::GetFaninEdges(const NodeDef& node,
bool include_controlling_edges) const {
std::unordered_set<Edge, HashEdge> result;
for (int i = 0; i < node.input_size(); ++i) {
Edge fanin;
fanin.tgt.node = const_cast<NodeDef*>(&node);
fanin.tgt.port_id = i;
string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
if (fanin.src.port_id < 0) {
if (!include_controlling_edges) {
break;
}
}
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.src.node = it->second;
result.insert(fanin);
}
}
return result;
}
} // end namespace grappler } // end namespace grappler
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -18,9 +18,16 @@ limitations under the License.
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
namespace tensorflow { namespace tensorflow {
@ -36,114 +43,290 @@ namespace grappler {
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef. namespace internal {
class GraphView {
// GraphViewInternal is a helper class to simplify graph traversal. It creates
// an immutable view of the nodes and edges represented by a GraphDef protocol
// buffer.
//
// There are two public classes implementing GraphViewInternal:
//
// - GraphView: constructed from the `const GraphDef` and doesn't allow
// to mutate underlying graph via input/output ports lookup functions (ports
// have const pointers to nodes).
//
// - MutableGraphView: constructed from the 'GraphDef` and allows to mutate
// the graph via input/output ports lookup functions (ports have non-const
// pointers to nodes), and also have couple additional functions to
// add/remove/replace nodes in the graph.
//
// --------------------------- !!! WARNING !!! ---------------------------------
// Removing nodes from the graph outside of MutableGraphView will
// lead to segfaults! Guaranteed by absl::string_view!
// -----------------------------------------------------------------------------
//
template <typename GraphDefT, typename NodeDefT>
class GraphViewInternal {
public: public:
struct Port { struct Port {
Port() = default; Port() : node(nullptr), port_id(0) {}
Port(NodeDef* n, int port) : node(n), port_id(port) {} Port(NodeDefT* n, int port) : node(n), port_id(port) {}
// TODO(prazek): ports should keep the constness of GraphView. The only way
// to modify graph through the view should be using MutableGraphView.
NodeDef* node = nullptr;
int port_id = -1;
bool operator==(const Port& other) const { bool operator==(const Port& other) const {
return node == other.node && port_id == other.port_id; return node == other.node && port_id == other.port_id;
} }
};
struct InputPort : public Port { template <typename H>
InputPort() = default; friend H AbslHashValue(H h, const Port& p) {
InputPort(NodeDef* n, int port_id) : Port(n, port_id) {} return H::combine(std::move(h), p.node, p.port_id);
InputPort(const NodeDef* n, int port_id) }
: Port(const_cast<NodeDef*>(n), port_id) {}
}; NodeDefT* node;
struct OutputPort : public Port { int port_id;
OutputPort() = default;
OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
}; };
struct HashPort { struct InputPort : public Port {
std::size_t operator()(const Port& port) const { using Port::Port;
return reinterpret_cast<std::size_t>(port.node) + port.port_id; };
}
struct OutputPort : public Port {
using Port::Port;
}; };
struct Edge { struct Edge {
OutputPort src; Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
InputPort tgt;
bool operator==(const Edge& other) const { bool operator==(const Edge& other) const {
return src == other.src && tgt == other.tgt; return src == other.src && dst == other.dst;
} }
};
struct HashEdge { template <typename H>
std::size_t operator()(const Edge& edge) const { friend H AbslHashValue(H h, const Edge& e) {
return HashPort()(edge.src) + HashPort()(edge.tgt); return H::combine(std::move(h), e.src, e.dst);
} }
OutputPort src;
InputPort dst;
}; };
explicit GraphView(GraphDef* graph); GraphDefT* graph() const { return graph_; }
GraphDef* GetGraph() const { return graph_; }
NodeDef* GetNode(const string& node_name) const; // Find a node by name or return `nullptr` if it's not in a graph view.
NodeDefT* GetNode(absl::string_view node_name) const {
return gtl::FindWithDefault(nodes_, node_name, nullptr);
}
// Get the specified input port. Note that the special '-1' port_id can be // Get the specified input port. Note that the special '-1' port_id can be
// used to access the controlling nodes (i.e. the nodes connected to node_name // used to access the controlling nodes (i.e. the nodes connected to node_name
// through an incoming control dependency). // through an incoming control dependency).
InputPort GetInputPort(const string& node_name, int port_id) const; InputPort GetInputPort(absl::string_view node_name, int port_id) const {
return InputPort(GetNode(node_name), port_id);
}
// Get the specified output port. Note that the special '-1' port_id can be // Get the specified output port. Note that the special '-1' port_id can be
// used to access the controlled nodes (i.e. the nodes connected to node_name // used to access the controlled nodes (i.e. the nodes connected to node_name
// through an outgoing control dependency). // through an outgoing control dependency).
OutputPort GetOutputPort(const string& node_name, int port_id) const; OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
return OutputPort(GetNode(node_name), port_id);
}
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin) // Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
// of an output (resp. input) port. // of an output (resp. input) port.
const std::unordered_set<InputPort, HashPort>& GetFanout( const absl::flat_hash_set<InputPort>& GetFanout(
const OutputPort& port) const; const OutputPort& port) const {
std::unordered_set<OutputPort, HashPort> GetFanin( return gtl::FindWithDefault(fanouts_, port, empty_set_);
const InputPort& port) const; }
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
if (port.port_id >= 0) return {GetRegularFanin(port)};
// Collect fanin for the control input.
absl::flat_hash_set<OutputPort> result;
for (int i = port.node->input_size() - 1; i >= 0; --i) {
TensorId tensor_id = ParseTensorName(port.node->input(i));
if (tensor_id.index() >= 0) break; // we reached regular inputs
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
}
return result;
}
// Special case: regular (i.e. non-control) input ports can only have one // Special case: regular (i.e. non-control) input ports can only have one
// fanin. // fanin.
const OutputPort GetRegularFanin(const InputPort& port) const; const OutputPort GetRegularFanin(const InputPort& port) const {
DCHECK_GE(port.port_id, 0);
if (port.port_id < 0) return OutputPort();
// Get all the input (resp. output) ports in the immediate fanout (resp fanin) TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
// of a node. Include the controlling nodes iff include_controlling_nodes is return GetOutputPort(tensor_id.node(), tensor_id.index());
// true. }
std::unordered_set<InputPort, HashPort> GetFanouts(
const NodeDef& node, bool include_controlled_nodes) const; // Get all the input (resp. output) ports in the immediate fanout (resp
std::unordered_set<OutputPort, HashPort> GetFanins( // fanin) of a node. Include the controlling nodes iff
const NodeDef& node, bool include_controlling_nodes) const; // include_controlling_nodes is true.
absl::flat_hash_set<InputPort> GetFanouts(
const NodeDef& node, bool include_controlled_nodes) const {
absl::flat_hash_set<InputPort> result;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
result.insert(it->second.begin(), it->second.end());
}
}
return result;
}
absl::flat_hash_set<OutputPort> GetFanins(
const NodeDef& node, bool include_controlling_nodes) const {
absl::flat_hash_set<OutputPort> result;
for (int i = 0; i < node.input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node.input(i));
if (tensor_id.index() < 0 && !include_controlling_nodes) break;
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
}
return result;
}
// Get the number of ports in the immediate fanin of a node. Count the // Get the number of ports in the immediate fanin of a node. Count the
// controlling nodes iff include_controlling_nodes is true. // controlling nodes iff include_controlling_nodes is true.
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const; int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
int count = 0;
for (const string& input : node.input()) {
if (!include_controlling_nodes && IsControlInput(input)) {
break;
}
count += 1;
}
return count;
}
// Get all the edge in the immediate fanout (resp fanin) of a node. Include // Get the number of ports in the immediate fanout of a node. Count the
// the control edges iff include_controlling_edges is true. // controlling nodes iff include_controlling_nodes is true.
std::unordered_set<Edge, HashEdge> GetFanoutEdges( int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
const NodeDef& node, bool include_controlled_edges) const; int count = 0;
std::unordered_set<Edge, HashEdge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const; OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlling_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) count += it->second.size();
}
return count;
}
// Get all the edges in the immediate fanout (resp fanin) of a node.
// Include the control edges iff include_controlling_edges is true.
absl::flat_hash_set<Edge> GetFanoutEdges(
const NodeDef& node, bool include_controlled_edges) const {
absl::flat_hash_set<Edge> result;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_edges ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
/*dst*/ *itr);
}
}
}
return result;
}
absl::flat_hash_set<Edge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const {
absl::flat_hash_set<Edge> result;
for (int i = 0; i < node.input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node.input(i));
if (tensor_id.index() < 0 && !include_controlling_edges) break;
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) {
result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
/*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
}
}
return result;
}
protected: protected:
// Add a new `node` to the graph. explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
void AddUniqueNodeOrDie(NodeDef* node);
// Add fanout to every `node` input.
void AddFanouts(NodeDef* node);
std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
GraphDef* MutableGraph() { return graph_; }
using FanoutsMapType = void AddUniqueNodeOrDie(NodeDefT* node) {
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>, auto result = nodes_.emplace(node->name(), node);
HashPort>; // TODO(ezhulenev): Replace CHECK with factory method returning
FanoutsMapType* MutableFanouts() { return &fanouts_; } // absl::StatusOr (when available).
CHECK(result.second) << "Non unique node name detected: " << node->name();
}
void AddFanouts(NodeDefT* node) {
for (int i = 0; i < node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node->input(i));
OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
if (output.port_id < 0) {
fanouts_[output].emplace(node, -1);
} else {
num_regular_outputs_[output.node] =
std::max(num_regular_outputs_[output.node], output.port_id);
fanouts_[output].emplace(node, i);
}
}
}
// Access to the mutable internal state for MutableGraphView.
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
return &nodes_;
}
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
mutable_fanouts() {
return &fanouts_;
}
private: private:
GraphDef* graph_; GraphDefT* graph_; // must outlive the graph view
std::unordered_map<string, NodeDef*> nodes_; absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
std::unordered_set<InputPort, HashPort> empty_set_; absl::flat_hash_set<InputPort> empty_set_;
FanoutsMapType fanouts_; absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
std::unordered_map<const NodeDef*, int> num_regular_outputs_; std::unordered_map<NodeDefT*, int> num_regular_outputs_;
};
} // namespace internal
// Immutable GraphView that keeps the constness of the GraphDef. If you need to
// mutate the graph or the nodes via the graph view lookup functions, see
// MutableGraphView.
class GraphView
: public internal::GraphViewInternal<const GraphDef, const NodeDef> {
public:
explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
for (const NodeDef& node : graph->node()) AddFanouts(&node);
}
}; };
} // end namespace grappler } // end namespace grappler

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/graph_view.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) {
const NodeDef* add_node = graph.GetNode("AddN"); const NodeDef* add_node = graph.GetNode("AddN");
EXPECT_NE(nullptr, add_node); EXPECT_NE(nullptr, add_node);
string fanouts;
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
strings::StrAppend(&fanouts,
strings::StrCat(fo.node->name(), ":", fo.port_id, " "));
}
EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts);
string fanins; absl::flat_hash_set<string> fanouts;
for (const auto& fi : graph.GetFanins(*add_node, false)) { absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
strings::StrAppend(&fanins, for (const auto& fo : graph.GetFanouts(*add_node, false)) {
strings::StrCat(fi.node->name(), ":", fi.port_id, " ")); fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
} }
EXPECT_EQ("Square_1:0 Square:0 ", fanins); EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
EXPECT_EQ(fanouts, expected_fanouts);
absl::flat_hash_set<string> fanins;
absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
for (const auto& fi : graph.GetFanins(*add_node, false)) {
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
}
EXPECT_EQ(graph.NumFanins(*add_node, false), 2);
EXPECT_EQ(fanins, expected_fanins);
} }
TEST_F(GraphViewTest, ControlDependencies) { TEST_F(GraphViewTest, ControlDependencies) {

View File

@ -19,8 +19,26 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
const absl::flat_hash_set<MutableGraphView::InputPort>&
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
port.port_id));
}
absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
const GraphView::InputPort& port) const {
return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
port.port_id));
}
const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
const GraphView::InputPort& port) const {
return GetRegularFanin(MutableGraphView::InputPort(
const_cast<NodeDef*>(port.node), port.port_id));
}
NodeDef* MutableGraphView::AddNode(NodeDef&& node) { NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
auto* node_in_graph = GetGraph()->add_node(); auto* node_in_graph = graph()->add_node();
*node_in_graph = std::move(node); *node_in_graph = std::move(node);
AddUniqueNodeOrDie(node_in_graph); AddUniqueNodeOrDie(node_in_graph);
@ -31,7 +49,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node, NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
const int output_port_id) { const int output_port_id) {
auto* node_in_graph = GetGraph()->add_node(); auto* node_in_graph = graph()->add_node();
*node_in_graph = std::move(node); *node_in_graph = std::move(node);
AddUniqueNodeOrDie(node_in_graph); AddUniqueNodeOrDie(node_in_graph);
@ -46,8 +64,7 @@ NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
void MutableGraphView::ReplaceInput(const NodeDef& old_input, void MutableGraphView::ReplaceInput(const NodeDef& old_input,
const NodeDef& new_input, const NodeDef& new_input,
const int output_port_id) { const int output_port_id) {
GraphView::OutputPort output_port = OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
GetOutputPort(old_input.name(), output_port_id);
auto fanout = GetFanout(output_port); auto fanout = GetFanout(output_port);
for (auto& input_port : fanout) { for (auto& input_port : fanout) {
input_port.node->set_input(input_port.port_id, new_input.name()); input_port.node->set_input(input_port.port_id, new_input.name());
@ -57,17 +74,17 @@ void MutableGraphView::ReplaceInput(const NodeDef& old_input,
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) { void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
for (const string& node_name_to_delete : nodes_to_delete) for (const string& node_name_to_delete : nodes_to_delete)
RemoveFanouts(MutableNodes()->at(node_name_to_delete)); RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
for (const string& node_name_to_delete : nodes_to_delete) for (const string& node_name_to_delete : nodes_to_delete)
MutableNodes()->erase(node_name_to_delete); mutable_nodes()->erase(node_name_to_delete);
EraseNodesFromGraph(nodes_to_delete, GetGraph()); EraseNodesFromGraph(nodes_to_delete, graph());
} }
void MutableGraphView::RemoveFanouts(NodeDef* node) { void MutableGraphView::RemoveFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) { for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin; OutputPort fanin;
string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = (*MutableNodes())[fanin_name]; fanin.node = (*mutable_nodes())[fanin_name];
InputPort input; InputPort input;
input.node = node; input.node = node;
@ -76,7 +93,7 @@ void MutableGraphView::RemoveFanouts(NodeDef* node) {
else else
input.port_id = i; input.port_id = i;
(*MutableFanouts())[fanin].erase(input); (*mutable_fanouts())[fanin].erase(input);
} }
} }

View File

@ -24,11 +24,25 @@ namespace grappler {
// A utility class to simplify the traversal of a GraphDef that, unlike // A utility class to simplify the traversal of a GraphDef that, unlike
// GraphView, supports updating the graph. Note that you should not modify the // GraphView, supports updating the graph. Note that you should not modify the
// graph separately, because the view will get out of sync. // graph separately, because the view will get out of sync.
class MutableGraphView : public GraphView {
public:
using GraphView::GraphView;
GraphDef* GetGraph() { return MutableGraph(); } class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
public:
explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node);
}
// Lookup fanouts/fanins using immutable ports.
using GraphViewInternal::GetFanout;
const absl::flat_hash_set<InputPort>& GetFanout(
const GraphView::OutputPort& port) const;
using GraphViewInternal::GetFanin;
absl::flat_hash_set<OutputPort> GetFanin(
const GraphView::InputPort& port) const;
using GraphViewInternal::GetRegularFanin;
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
// Adds a new node to graph and updates the view. // Adds a new node to graph and updates the view.
NodeDef* AddNode(NodeDef&& node); NodeDef* AddNode(NodeDef&& node);

View File

@ -26,7 +26,8 @@ namespace {
bool FindChildWithName(const MutableGraphView& graph, bool FindChildWithName(const MutableGraphView& graph,
const string& output_port_name, const string& output_port_name,
const string& input_name) { const string& input_name) {
GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0); MutableGraphView::OutputPort output_port =
graph.GetOutputPort(output_port_name, 0);
auto fanout = graph.GetFanout(output_port); auto fanout = graph.GetFanout(output_port);
for (auto& input_port : fanout) { for (auto& input_port : fanout) {
if (input_port.node->name() == input_name) return true; if (input_port.node->name() == input_name) return true;
@ -59,10 +60,10 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
GraphDef new_graph = item.graph; GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph); MutableGraphView graph(&new_graph);
GraphView::InputPort input = graph.GetInputPort("AddN", 0); MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
EXPECT_EQ("AddN", input.node->name()); EXPECT_EQ("AddN", input.node->name());
EXPECT_EQ(0, input.port_id); EXPECT_EQ(0, input.port_id);
GraphView::OutputPort fanin = graph.GetRegularFanin(input); MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
EXPECT_EQ("Square", fanin.node->name()); EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id); EXPECT_EQ(0, fanin.port_id);
@ -89,7 +90,7 @@ TEST(MutableGraphViewTest, InsertNodes) {
GraphDef new_graph = item.graph; GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph); MutableGraphView graph(&new_graph);
GraphView::InputPort input = graph.GetInputPort("AddN", 0); MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
NodeDef new_node = *input.node; NodeDef new_node = *input.node;
new_node.set_name("new_node"); new_node.set_name("new_node");

View File

@ -145,8 +145,8 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:functions",
@ -422,8 +422,8 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/clusters:cluster",
@ -625,12 +625,13 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame", "//tensorflow/core/grappler/utils:frame",
"@com_google_absl//absl/container:flat_hash_set",
], ],
) )
@ -663,8 +664,8 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/costs:graph_properties",

View File

@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
const FunctionDef& fused_function, const FunctionDef& fused_function,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef fused_node; NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(),
&fused_node); &fused_node);
fused_node.set_op("FilterDataset"); fused_node.set_op("FilterDataset");

View File

@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper(
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef node; NodeDef node;
node.set_op(kConstOpName); node.set_op(kConstOpName);
SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node); SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
(*node.mutable_attr())["dtype"].set_type(dtype); (*node.mutable_attr())["dtype"].set_type(dtype);
std::unique_ptr<tensorflow::TensorProto> tensor = std::unique_ptr<tensorflow::TensorProto> tensor =
@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper(
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
NodeDef node; NodeDef node;
node.set_op("Placeholder"); node.set_op("Placeholder");
SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
(*node.mutable_attr())["dtype"].set_type(dtype); (*node.mutable_attr())["dtype"].set_type(dtype);
TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
shape->set_unknown_rank(false); shape->set_unknown_rank(false);
@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
if (!name.empty()) { if (!name.empty()) {
node.set_name(string(name)); node.set_name(string(name));
} else { } else {
SetUniqueGraphNodeName(op, graph->GetGraph(), &node); SetUniqueGraphNodeName(op, graph->graph(), &node);
} }
node.set_op(string(op)); node.set_op(string(op));
for (const string& input : inputs) { for (const string& input : inputs) {
@ -228,7 +228,7 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr; if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
return graph.GetRegularFanin(input_port).node; return graph.GetRegularFanin(input_port).node;
} }

View File

@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph); NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph())); EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph()));
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true); EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
} }
@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph); NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
EXPECT_TRUE( EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph()));
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14); EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
} }
@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph); NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph())); EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph()));
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14); EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
} }
@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* int_node = AddScalarConstNode<int>(42, &graph); NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph())); EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph()));
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42); EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
} }
@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph); NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph())); EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph()));
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42); EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
} }
@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph); NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
EXPECT_TRUE( EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph()));
ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph()));
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello"); EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
} }
@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) {
TEST(GraphUtilsTest, ContainsGraphNodeWithName) { TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
AddNode("A", "OpA", {}, {}, &graph); AddNode("A", "OpA", {}, {}, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph())); EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));
graph.DeleteNodes({"A"}); graph.DeleteNodes({"A"});
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph())); EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
} }
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
TEST(GraphUtilsTest, ContainsNodeWithOp) { TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
AddNode("A", "OpA", {}, {}, &graph); AddNode("A", "OpA", {}, {}, &graph);
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph())); EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));
graph.DeleteNodes({"A"}); graph.DeleteNodes({"A"});
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph())); EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
} }
TEST(GraphUtilsTest, FindGraphNodeWithName) { TEST(GraphUtilsTest, FindGraphNodeWithName) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph); AddNode("A", "OpA", {}, {}, &graph);
EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1); EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1);
graph.DeleteNodes({"A"}); graph.DeleteNodes({"A"});
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
} }
TEST(GraphUtilsTest, FindGraphFunctionWithName) { TEST(GraphUtilsTest, FindGraphFunctionWithName) {
@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
TEST(GraphUtilsTest, FindGraphNodeWithOp) { TEST(GraphUtilsTest, FindGraphNodeWithOp) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph); AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0); EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);
graph.DeleteNodes({"B"}); graph.DeleteNodes({"B"});
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1); EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1); EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
} }
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) { TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
GraphDef graph_def; GraphDef graph_def;
MutableGraphView graph(&graph_def); MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1); EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph); AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph); AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph); AddNode("A2", "OpA", {"B"}, {}, &graph);
std::vector<int> result_indices = std::vector<int> result_indices =
FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); FindAllGraphNodesWithOp("OpA", *graph.graph());
EXPECT_EQ(result_indices.size(), 2); EXPECT_EQ(result_indices.size(), 2);
EXPECT_EQ(result_indices.at(0), 0); EXPECT_EQ(result_indices.at(0), 0);
EXPECT_EQ(result_indices.at(1), 2); EXPECT_EQ(result_indices.at(1), 2);
graph.DeleteNodes({"A2"}); graph.DeleteNodes({"A2"});
std::vector<int> result_indices_new = std::vector<int> result_indices_new =
FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); FindAllGraphNodesWithOp("OpA", *graph.graph());
EXPECT_EQ(result_indices_new.size(), 1); EXPECT_EQ(result_indices_new.size(), 1);
EXPECT_EQ(result_indices_new.at(0), 0); EXPECT_EQ(result_indices_new.at(0), 0);
} }

View File

@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
const FunctionDef& stateless_function, const FunctionDef& stateless_function,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef stateless_map; NodeDef stateless_map;
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(),
&stateless_map); &stateless_map);
stateless_map.set_op("MapDataset"); stateless_map.set_op("MapDataset");
@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef random_dataset; NodeDef random_dataset;
random_dataset.set_op("RandomDataset"); random_dataset.set_op("RandomDataset");
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(),
&random_dataset); &random_dataset);
const auto* seed = graph_utils::AddScalarConstNode<int64>( const auto* seed = graph_utils::AddScalarConstNode<int64>(
@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
NodeDef batch_dataset; NodeDef batch_dataset;
batch_dataset.set_op("BatchDatasetV2"); batch_dataset.set_op("BatchDatasetV2");
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(),
&batch_dataset); &batch_dataset);
const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph); const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph); const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef zip_node; NodeDef zip_node;
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(),
&zip_node); &zip_node);
zip_node.set_op("ZipDataset"); zip_node.set_op("ZipDataset");

View File

@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node; NodeDef new_node;
new_node.set_op(kInsertOpName); new_node.set_op(kInsertOpName);
graph_utils::SetUniqueGraphNodeName( graph_utils::SetUniqueGraphNodeName(
strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(), strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node);
&new_node);
// Set the input of LatencyDataset node as `node` // Set the input of LatencyDataset node as `node`
new_node.add_input(node.name()); new_node.add_input(node.name());
@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
// node corresponds to a `Dataset` op. // node corresponds to a `Dataset` op.
continue; continue;
} }
GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0); MutableGraphView::OutputPort output_port =
graph.GetOutputPort(node.name(), 0);
auto fanout = graph.GetFanout(output_port); auto fanout = graph.GetFanout(output_port);
if (fanout.size() > 1) { if (fanout.size() > 1) {
LOG(WARNING) << node.name() << " has fanout size " << fanout.size(); LOG(WARNING) << node.name() << " has fanout size " << fanout.size();

View File

@ -29,7 +29,7 @@ namespace {
NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) { NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef numa_aware_node = node; NodeDef numa_aware_node = node;
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(),
&numa_aware_node); &numa_aware_node);
numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset"); numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
return numa_aware_node; return numa_aware_node;

View File

@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef new_node; NodeDef new_node;
new_node.set_op(kFusedOpName); new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(), graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node);
&new_node);
// Set the `input` input argument. // Set the `input` input argument.
new_node.add_input(map_node.input(0)); new_node.add_input(map_node.input(0));

View File

@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) {
GraphDef output; GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
} }
} // namespace } // namespace

View File

@ -37,8 +37,7 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
const FunctionDef& fused_function, const FunctionDef& fused_function,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef fused_node; NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
&fused_node);
fused_node.set_op("MapDataset"); fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0)); fused_node.add_input(map_node.input(0));
@ -72,8 +71,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node,
const NodeDef& filter_node, const NodeDef& filter_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef filter_by_component; NodeDef filter_by_component;
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(),
graph->GetGraph(), &filter_by_component); &filter_by_component);
filter_by_component.set_op("FilterByLastComponentDataset"); filter_by_component.set_op("FilterByLastComponentDataset");
filter_by_component.add_input(fused_map_node.name()); filter_by_component.add_input(fused_map_node.name());

View File

@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
const FunctionDef& fused_function, const FunctionDef& fused_function,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef fused_node; NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
&fused_node);
fused_node.set_op("MapDataset"); fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0)); fused_node.add_input(parent_map_node.input(0));

View File

@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function,
NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
NodeDef parallel_map = map_node; NodeDef parallel_map = map_node;
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(),
&parallel_map); &parallel_map);
parallel_map.set_op("ParallelMapDataset"); parallel_map.set_op("ParallelMapDataset");
// TODO(b/114475558): We want to set `num_parallel_calls` to a special value, // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,

View File

@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef batch_node; NodeDef batch_node;
batch_node.set_op(old_batch_node.op()); batch_node.set_op(old_batch_node.op());
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(), graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(),
&batch_node); &batch_node);
// Set the `input_dataset` input argument // Set the `input_dataset` input argument
@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
MutableGraphView* graph) { MutableGraphView* graph) {
NodeDef map_node; NodeDef map_node;
map_node.set_op(old_map_node.op()); map_node.set_op(old_map_node.op());
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(), graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node);
&map_node);
// Set the `input_dataset` input argument // Set the `input_dataset` input argument
map_node.add_input(new_batch_node.name()); map_node.add_input(new_batch_node.name());

View File

@ -30,7 +30,7 @@ namespace tensorflow {
namespace grappler { namespace grappler {
namespace { namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
if (take_node.op() != "TakeDataset") return false; if (take_node.op() != "TakeDataset") return false;
const auto& count_node = *graph.GetNode(take_node.input(1)); const auto& count_node = *graph.GetNode(take_node.input(1));
@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) {
return node.attr().at("value").tensor().int64_val(0) == value; return node.attr().at("value").tensor().int64_val(0) == value;
} }
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) {
if (skip_node.op() != "SkipDataset") return false; if (skip_node.op() != "SkipDataset") return false;
// We are looking only for skip(0) nodes. // We are looking only for skip(0) nodes.
return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
} }
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false; if (repeat_node.op() != "RepeatDataset") return false;
// We are looking only for repeat(1) nodes. // We are looking only for repeat(1) nodes.
return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
} }
bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) { bool IsPrefetchZero(const NodeDef& prefetch_node,
const MutableGraphView& graph) {
if (prefetch_node.op() != "PrefetchDataset") return false; if (prefetch_node.op() != "PrefetchDataset") return false;
// We are looking only for prefetch(0) nodes. // We are looking only for prefetch(0) nodes.
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0); return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
} }
bool IsNoOp(const NodeDef& node, const GraphView& graph) { bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
return IsTakeAll(node, graph) || IsSkipNone(node, graph) || return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph); IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
} }

View File

@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
GraphDef output; GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output)); EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
} }
} // namespace } // namespace

View File

@ -31,8 +31,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/functions.h"
@ -219,8 +219,7 @@ class FunctionOptimizerContext {
: grappler_item_id_(item.id), : grappler_item_id_(item.id),
graph_version_(item.graph.versions().producer()), graph_version_(item.graph.versions().producer()),
function_library_(OpRegistry::Global(), item.graph.library()), function_library_(OpRegistry::Global(), item.graph.library()),
// GraphView doesn't not modify the graph or the nodes. graph_view_(&item.graph) {
graph_view_(const_cast<GraphDef*>(&item.graph)) {
InitializeTrulyConstNodes(item); InitializeTrulyConstNodes(item);
InitializeInlinedFunctions(opt_level, item); InitializeInlinedFunctions(opt_level, item);
InitializeFetchNodes(item); InitializeFetchNodes(item);
@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Function specialization might change the number of function outputs, so we // Function specialization might change the number of function outputs, so we
// have to process the final optimized graph and update all the node mapping. // have to process the final optimized graph and update all the node mapping.
if (ctx.RequiresOutputMapping()) { if (ctx.RequiresOutputMapping()) {
GraphView optimized_graph_view(optimized_graph); MutableGraphView optimized_graph_view(optimized_graph);
for (const auto& output_mapping : ctx.output_mappings()) { for (const auto& output_mapping : ctx.output_mappings()) {
const auto& node_name = output_mapping.first; const auto& node_name = output_mapping.first;
const auto& mappings = output_mapping.second; const auto& mappings = output_mapping.second;
@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
int to = mapping.second; int to = mapping.second;
// Get the output port corresponding to the old output position. // Get the output port corresponding to the old output position.
GraphView::OutputPort from_port = MutableGraphView::OutputPort from_port =
optimized_graph_view.GetOutputPort(node_name, from); optimized_graph_view.GetOutputPort(node_name, from);
// Update all input ports that read from old output port. // Update all input ports that read from old output port.
for (GraphView::InputPort to_port : for (MutableGraphView::InputPort to_port :
optimized_graph_view.GetFanout(from_port)) { optimized_graph_view.GetFanout(from_port)) {
*to_port.node->mutable_input(to_port.port_id) = *to_port.node->mutable_input(to_port.port_id) =
strings::StrCat(node_name, ":", to); strings::StrCat(node_name, ":", to);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
@ -29,8 +30,8 @@ limitations under the License.
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h" #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
return Status::OK(); return Status::OK();
} }
Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node, Status CheckForDeadFanout(const MutableGraphView& view,
const NodeMap& node_map, DeviceBase* cpu_device, const NodeDef& switch_node, const NodeMap& node_map,
ResourceMgr* resource_mgr, bool* has_dead_fanout, DeviceBase* cpu_device, ResourceMgr* resource_mgr,
int* dead_fanout) { bool* has_dead_fanout, int* dead_fanout) {
*has_dead_fanout = false; *has_dead_fanout = false;
GraphView::InputPort switch_loopcond_port(&switch_node, 1); GraphView::InputPort switch_loopcond_port(&switch_node, 1);
NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node; const NodeDef* switch_predicate =
view.GetRegularFanin(switch_loopcond_port).node;
// CASE 1: Control is a constant. // CASE 1: Control is a constant.
if (IsConstant(*switch_predicate)) { if (IsConstant(*switch_predicate)) {
@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
} }
GraphView::InputPort switch_input_port(&switch_node, 0); GraphView::InputPort switch_input_port(&switch_node, 0);
NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node; const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
// CASE 2: Zero-iteration while loop. // CASE 2: Zero-iteration while loop.
// We check if its a while loop such that the condition is a simple binary // We check if its a while loop such that the condition is a simple binary
@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches(
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs; std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
// TODO(bsteiner): also rewrite switches as identity. For now we just record // TODO(bsteiner): also rewrite switches as identity. For now we just record
// them // them
std::unordered_set<GraphView::OutputPort, GraphView::HashPort> absl::flat_hash_set<GraphView::OutputPort> identity_switches;
identity_switches;
GraphView view(optimized_graph); MutableGraphView view(optimized_graph);
for (const NodeDef& node : optimized_graph->node()) { for (const NodeDef& node : optimized_graph->node()) {
if (!IsSwitch(node)) { if (!IsSwitch(node)) {
continue; continue;
@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches(
if (!has_dead_fanout) { if (!has_dead_fanout) {
continue; continue;
} }
GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout); GraphView::OutputPort dead(&node, dead_fanout);
identity_switches.insert(dead); identity_switches.insert(dead);
SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs; SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
for (const GraphView::InputPort& port : view.GetFanout(dead)) { zombie_inputs;
for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
if (dead_nodes.find(port.node) == dead_nodes.end()) { if (dead_nodes.find(port.node) == dead_nodes.end()) {
zombie_inputs.PushBack(port); zombie_inputs.PushBack(port);
} }
@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches(
dead_merge_inputs; dead_merge_inputs;
bool found_node_to_preserve = false; bool found_node_to_preserve = false;
while (!found_node_to_preserve && !zombie_inputs.Empty()) { while (!found_node_to_preserve && !zombie_inputs.Empty()) {
GraphView::InputPort dead = zombie_inputs.PopBack(); MutableGraphView::InputPort dead = zombie_inputs.PopBack();
if (nodes_to_preserve.find(dead.node->name()) != if (nodes_to_preserve.find(dead.node->name()) !=
nodes_to_preserve.end()) { nodes_to_preserve.end()) {
found_node_to_preserve = true; found_node_to_preserve = true;
@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches(
found_node_to_preserve = true; found_node_to_preserve = true;
break; break;
} }
GraphView::OutputPort value_index(dead.node, 1); MutableGraphView::OutputPort value_index(dead.node, 1);
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
index_fanout = view.GetFanout(value_index); view.GetFanout(value_index);
if (!index_fanout.empty()) { if (!index_fanout.empty()) {
// The 2nd output (that indicates which input is propagated) is // The 2nd output (that indicates which input is propagated) is
// connected. This never happens in practice, so we'll just skip this // connected. This never happens in practice, so we'll just skip this
@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches(
} }
if (fully_dead) { if (fully_dead) {
local_dead_nodes.insert(dead.node); local_dead_nodes.insert(dead.node);
for (const GraphView::InputPort& port : for (const MutableGraphView::InputPort& port :
view.GetFanouts(*dead.node, true)) { view.GetFanouts(*dead.node, true)) {
zombie_inputs.PushBack(port); zombie_inputs.PushBack(port);
} }
@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches(
break; break;
} else { } else {
if (local_dead_nodes.insert(dead.node).second) { if (local_dead_nodes.insert(dead.node).second) {
for (const GraphView::InputPort& dead_fanout : for (const MutableGraphView::InputPort& dead_fanout :
view.GetFanouts(*dead.node, true)) { view.GetFanouts(*dead.node, true)) {
zombie_inputs.PushBack(dead_fanout); zombie_inputs.PushBack(dead_fanout);
} }

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_memory.h" #include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h" #include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h" #include "tensorflow/core/grappler/optimizers/static_schedule.h"
@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
// Look for AddN nodes (and equivalent) and record input names. // Look for AddN nodes (and equivalent) and record input names.
GraphView view(&item->graph); MutableGraphView view(&item->graph);
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list; std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
for (NodeDef& node : *item->graph.mutable_node()) { for (NodeDef& node : *item->graph.mutable_node()) {
@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
for (int i = 0; i < node->input_size(); ++i) { for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i); const string& input = node->input(i);
const string node_name = NodeName(input); const string node_name = NodeName(input);
NodeDef* node = view.GetNode(node_name); const NodeDef* node = view.GetNode(node_name);
input_topo_index.push_back(topo_order.at(node)); input_topo_index.push_back(topo_order.at(node));
} }
int min_input_topo_index = INT_MAX; int min_input_topo_index = INT_MAX;
@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger(
return nullptr; return nullptr;
} }
static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) { static bool IsSwappable(const MutableGraphView& graph,
MutableGraphView::OutputPort output) {
const NodeDef& node = *output.node; const NodeDef& node = *output.node;
// There is no point in swapping out persistent tensors, since the tensor will // There is no point in swapping out persistent tensors, since the tensor will
// continue to use memory. // continue to use memory.
@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
// If placed on the same device, these nodes are just forwarding references // If placed on the same device, these nodes are just forwarding references
// to their input. Therefore they are swappable iff their fanin is swappable // to their input. Therefore they are swappable iff their fanin is swappable
// or it resides on a different device. // or it resides on a different device.
GraphView::InputPort input; MutableGraphView::InputPort input;
input.node = output.node; input.node = output.node;
input.port_id = 0; input.port_id = 0;
GraphView::OutputPort fanin = graph.GetRegularFanin(input); MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
if (fanin.node->device() == node.device()) { if (fanin.node->device() == node.device()) {
return IsSwappable(graph, fanin); return IsSwappable(graph, fanin);
} }
@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
} }
static NodeDef* FindSwapOutTrigger( static NodeDef* FindSwapOutTrigger(
const NodeDef* node, int input_id, const GraphView& view, const NodeDef* node, int input_id, const MutableGraphView& view,
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>& const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
execution_times) { execution_times) {
// Find the output port that generated the tensor to swap. // Find the output port that generated the tensor to swap.
GraphView::InputPort swap; MutableGraphView::InputPort swap;
swap.node = const_cast<NodeDef*>(node); swap.node = const_cast<NodeDef*>(node);
swap.port_id = input_id; swap.port_id = input_id;
GraphView::OutputPort generator = view.GetRegularFanin(swap); MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
if (!generator.node) { if (!generator.node) {
return nullptr; return nullptr;
} }
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout = const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
view.GetFanout(generator); view.GetFanout(generator);
NodeDef* trigger = nullptr; NodeDef* trigger = nullptr;
Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity()); Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger(
return trigger; return trigger;
} }
static bool IsSwappable(GraphView::InputPort input) { static bool IsSwappable(MutableGraphView::InputPort input) {
const NodeDef& node = *input.node; const NodeDef& node = *input.node;
const OpDef* op_def; const OpDef* op_def;
@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) {
} }
struct MemInfo { struct MemInfo {
GraphView::OutputPort port; MutableGraphView::OutputPort port;
int64 memory_used; int64 memory_used;
std::vector<GraphView::InputPort> uses_left; std::vector<MutableGraphView::InputPort> uses_left;
double fitness; double fitness;
bool operator<(const MemInfo& other) const { return fitness < other.fitness; } bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates(
std::vector<MemInfo> mem_state; std::vector<MemInfo> mem_state;
GraphView graph(&item->graph); MutableGraphView graph(&item->graph);
for (const auto& live_tensor : mem_usage.live_tensors) { for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.memory_used <= 1024) { if (live_tensor.memory_used <= 1024) {
// Don't bother with small tensors. // Don't bother with small tensors.
@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates(
if (skip_list->find(live_tensor.node) != skip_list->end()) { if (skip_list->find(live_tensor.node) != skip_list->end()) {
continue; continue;
} }
GraphView::OutputPort port = MutableGraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id); graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
if (!IsSwappable(graph, port)) { if (!IsSwappable(graph, port)) {
continue; continue;
@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates(
Costs::Duration allocation_time = live_tensor.allocation_time; Costs::Duration allocation_time = live_tensor.allocation_time;
Costs::Duration earliest_use(Costs::Duration::infinity()); Costs::Duration earliest_use(Costs::Duration::infinity());
bool valid = true; bool valid = true;
for (GraphView::InputPort input : graph.GetFanout(port)) { for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
// Get execution time. // Get execution time.
auto it = op_completion_times.find(input.node->name()); auto it = op_completion_times.find(input.node->name());
if (it == op_completion_times.end()) { if (it == op_completion_times.end()) {
@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates(
// the values do not fit into any integral type. // the values do not fit into any integral type.
mem_info.fitness = mem_info.fitness =
MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) / MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) + MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
MathUtil::IPow<double>((allocation_time - peak_time).count(), 2); MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness; mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info); mem_state.push_back(mem_info);
@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates(
std::sort(mem_state.begin(), mem_state.end()); std::sort(mem_state.begin(), mem_state.end());
for (const MemInfo& mem_info : mem_state) { for (const MemInfo& mem_info : mem_state) {
for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) { for (const MutableGraphView::InputPort fanout_to_swap :
mem_info.uses_left) {
VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":" VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
<< fanout_to_swap.port_id << " of tensor " << fanout_to_swap.port_id << " of tensor "
<< mem_info.port.node->name() << ":" << mem_info.port.port_id << mem_info.port.node->name() << ":" << mem_info.port.port_id
@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
for (const auto& node : item->graph.node()) { for (const auto& node : item->graph.node()) {
name_map[node.name()] = &node; name_map[node.name()] = &node;
} }
GraphView view(&item->graph); MutableGraphView view(&item->graph);
bool updated_graph = false; bool updated_graph = false;

View File

@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphProperties properties(item); GraphProperties properties(item);
bool inferred_properties = false; bool inferred_properties = false;
GraphView graph(optimized_graph); MutableGraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more // The product of all the dimensions in a tensor shape can be expressed more
// simply as the size of the tensor. // simply as the size of the tensor.
@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsShape(node)) { if (!IsShape(node)) {
continue; continue;
} }
for (GraphView::InputPort fanout : for (MutableGraphView::InputPort fanout :
graph.GetFanout(GraphView::OutputPort(&node, 0))) { graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
if (fanout.node->op() != "Prod") { if (fanout.node->op() != "Prod") {
continue; continue;
} }
@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// rewrite the whole expression directly as a Size operation. // rewrite the whole expression directly as a Size operation.
continue; continue;
} }
const GraphView::OutputPort reduce_indices = const MutableGraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1)); graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
if (!inferred_properties) { if (!inferred_properties) {
// Infer properties lazily in case they are not needed. // Infer properties lazily in case they are not needed.
TF_RETURN_IF_ERROR(properties.InferStatically(false)); TF_RETURN_IF_ERROR(properties.InferStatically(false));
@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// is possible whenever the symbolic dimensions in the numerator and // is possible whenever the symbolic dimensions in the numerator and
// denominator cancel each other. // denominator cancel each other.
if (node.op() == "Div") { if (node.op() == "Div") {
const GraphView::OutputPort input1 = const MutableGraphView::OutputPort input1 =
graph.GetRegularFanin(GraphView::InputPort(&node, 0)); graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
const GraphView::OutputPort input2 = const MutableGraphView::OutputPort input2 =
graph.GetRegularFanin(GraphView::InputPort(&node, 1)); graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
if (!IsSize(*input1.node) || !IsSize(*input2.node)) { if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue; continue;
} }

View File

@ -101,6 +101,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:mutable_graph_view",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
], ],
) )

View File

@ -21,8 +21,11 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
void ReverseDfs( namespace {
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
template <typename GraphViewType>
void ReverseDfsInternal(
const GraphViewType& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order, const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order, const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) { const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
@ -79,5 +82,25 @@ void ReverseDfs(
} }
} }
} // namespace
void ReverseDfs(
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
ReverseDfsInternal<GraphView>(graph_view, from, pre_order, post_order,
on_back_edge);
}
void ReverseDfs(
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
ReverseDfsInternal<MutableGraphView>(graph_view, from, pre_order, post_order,
on_back_edge);
}
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <functional> #include <functional>
#include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
@ -34,6 +35,12 @@ void ReverseDfs(
const std::function<void(const NodeDef*)>& post_order, const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge); const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
void ReverseDfs(
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow

View File

@ -14,9 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/grappler/utils/traversal.h" #include "tensorflow/core/grappler/utils/traversal.h"
//#include "tensorflow/core/framework/node_def.pb.h"
//#include "tensorflow/core/lib/core/status_test_util.h"
//#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) {
found_back_edge = true; found_back_edge = true;
}); });
EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order); // Pre/Post order traversals are non deterministic because a node fanin is an
EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order); // absl::flat_hash_set with non deterministic traversal order.
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
std::set<ValidTraversal> valid_traversals = {
// pre_order post_order
{{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}},
{{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}};
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
EXPECT_FALSE(found_back_edge); EXPECT_FALSE(found_back_edge);
} }
@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) {
back_edges.push_back(strings::StrCat(src->name(), "->", dst->name())); back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
}); });
EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order); // Pre/Post order traversals are non deterministic because a node fanin is an
EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order); // absl::flat_hash_set with non deterministic traversal order.
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
std::set<ValidTraversal> valid_traversals = {
// pre_order post_order
{{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}},
{{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}},
{{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}};
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
EXPECT_EQ(std::vector<string>({"4->3"}), back_edges); EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
} }