[Grappler] Preserve constness of the GraphDef in GraphView.

1. Split GrapView into GraphView and MutableGraphView with separate {Input/Output}Port types with different node pointer constness.

2. Properly use GraphView and MutableGraphView in graph properties, and get rid of const_cast.

3. Remove const_cast in function optimizer.

4. Migrate GraphView to absl containers and hash

PiperOrigin-RevId: 219488040
This commit is contained in:
Eugene Zhulenev 2018-10-31 09:42:35 -07:00 committed by TensorFlower Gardener
parent 92e604060a
commit 3eeaf9f1e1
32 changed files with 503 additions and 445 deletions

View File

@ -69,6 +69,9 @@ cc_library(
":utils",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
],
)
@ -82,6 +85,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)

View File

@ -44,7 +44,7 @@ cc_library(
"@com_google_absl//absl/memory",
"//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",

View File

@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
@ -456,10 +456,10 @@ class SymbolicShapeRefiner {
const GraphView& graph,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
: graph_(graph),
function_library_(OpRegistry::Global(), graph.GetGraph()->library()),
function_library_(OpRegistry::Global(), graph.graph()->library()),
fed_ports_(fed_ports) {
graph_def_version_ = graph.GetGraph()->versions().producer();
node_to_context_.reserve(graph.GetGraph()->node_size());
graph_def_version_ = graph.graph()->versions().producer();
node_to_context_.reserve(graph.graph()->node_size());
}
const GraphView& graph() const { return graph_; }
@ -512,7 +512,7 @@ class SymbolicShapeRefiner {
// Placeholder with Const) don't affect one in
// fun_to_grappler_function_item_.
GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
MutableGraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
@ -532,7 +532,7 @@ class SymbolicShapeRefiner {
"Function inputs should not contain control nodes.");
}
NodeDef* input_node = graph_.GetNode(node_name);
const NodeDef* input_node = graph_.GetNode(node_name);
if (input_node == nullptr) {
return errors::FailedPrecondition(node_name,
" was not found in the graph.");
@ -566,7 +566,7 @@ class SymbolicShapeRefiner {
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
const string& input = function_node->input(i);
const string& node_name = NodeName(input);
NodeDef* input_node = graph_.GetNode(node_name);
const NodeDef* input_node = graph_.GetNode(node_name);
if (IsConstant(*input_node)) {
TF_CHECK_OK(
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
@ -1441,8 +1441,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
continue;
}
ShapeHandle input = in->output(fanin.src.port_id);
CHECK_EQ(fanin.tgt.node, node);
c->SetInput(fanin.tgt.port_id, input);
CHECK_EQ(fanin.dst.node, node);
c->SetInput(fanin.dst.port_id, input);
if (!out_initialized) {
out_initialized = true;
out = input;
@ -1673,7 +1673,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
}
}
GraphView graph_view(const_cast<GraphDef*>(&item_.graph));
GraphView graph_view(&item_.graph);
// List the resources and the nodes using them. Also collect the Merge nodes,
// fed nodes, and primary inputs.
@ -1725,10 +1725,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
for (const auto& resource : resources) {
for (const NodeDef* src : resource.second.first) {
resource_handles[src] = resource.first;
for (const NodeDef* tgt : resource.second.second) {
for (const NodeDef* dst : resource.second.second) {
// Add control edges from enqueue to dequeue nodes to ensure they are
// processed in their logical order.
extra_deps.emplace_back(src, tgt);
extra_deps.emplace_back(src, dst);
}
}
}

View File

@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
return OpPortIdToArgId(node, op.input_arg(), port_id);
}
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
AddUniqueNodeOrDie(node);
}
for (NodeDef& node : *graph_->mutable_node()) {
AddFanouts(&node);
}
}
void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
auto result = nodes_.emplace(node->name(), node);
// Check that the graph doesn't contain multiple nodes with the same name.
CHECK(result.second) << "Non unique node name detected: " << node->name();
}
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
input.node = node;
if (fanin.port_id < 0) {
input.port_id = -1;
} else {
input.port_id = i;
num_regular_outputs_[fanin.node] =
std::max(num_regular_outputs_[fanin.node], fanin.port_id);
}
fanouts_[fanin].insert(input);
}
}
NodeDef* GraphView::GetNode(const string& node_name) const {
auto it = nodes_.find(node_name);
if (it == nodes_.end()) {
return nullptr;
}
return it->second;
}
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
int port_id) const {
InputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id input ports
result.port_id = port_id;
return result;
}
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
int port_id) const {
OutputPort result;
result.node = GetNode(node_name);
// TODO(bsteiner): verify that the node has at least port_id output ports
result.port_id = port_id;
return result;
}
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
GraphView::GetFanout(const GraphView::OutputPort& port) const {
auto it = fanouts_.find(port);
if (it == fanouts_.end()) {
return empty_set_;
}
return it->second;
}
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
GraphView::GetFanin(const GraphView::InputPort& port) const {
std::unordered_set<GraphView::OutputPort, GraphView::HashPort> result;
if (port.port_id >= 0) {
result.insert(GetRegularFanin(port));
} else {
for (int i = port.node->input_size() - 1; i >= 0; --i) {
OutputPort fanin;
string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id);
if (fanin.port_id < 0) {
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.node = it->second;
result.insert(fanin);
}
} else {
break;
}
}
}
return result;
}
const GraphView::OutputPort GraphView::GetRegularFanin(
const GraphView::InputPort& port) const {
CHECK_LE(0, port.port_id);
OutputPort fanin;
string fanin_name =
ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
auto it = nodes_.find(fanin_name);
if (it == nodes_.end()) {
fanin.node = nullptr;
} else {
fanin.node = it->second;
}
return fanin;
}
std::unordered_set<GraphView::InputPort, GraphView::HashPort>
GraphView::GetFanouts(const NodeDef& node,
bool include_controlled_nodes) const {
std::unordered_set<InputPort, HashPort> result;
OutputPort port;
port.node = const_cast<NodeDef*>(&node);
const int first_port_id = include_controlled_nodes ? -1 : 0;
auto it = num_regular_outputs_.find(&node);
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
result.insert(it->second.begin(), it->second.end());
}
}
return result;
}
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
GraphView::GetFanins(const NodeDef& node,
bool include_controlling_nodes) const {
std::unordered_set<OutputPort, HashPort> result;
for (int i = 0; i < node.input_size(); ++i) {
OutputPort fanin;
string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
if (fanin.port_id < 0) {
if (!include_controlling_nodes) {
break;
}
}
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.node = it->second;
result.insert(fanin);
}
}
return result;
}
int GraphView::NumFanins(const NodeDef& node,
bool include_controlling_nodes) const {
int count = 0;
for (const string& input : node.input()) {
if (!include_controlling_nodes && IsControlInput(input)) {
break;
}
count += 1;
}
return count;
}
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
GraphView::GetFanoutEdges(const NodeDef& node,
bool include_controlled_edges) const {
std::unordered_set<Edge, HashEdge> result;
OutputPort port;
port.node = const_cast<NodeDef*>(&node);
const int first_port_id = include_controlled_edges ? -1 : 0;
auto it = num_regular_outputs_.find(&node);
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
Edge fanout;
fanout.src.node = const_cast<NodeDef*>(&node);
fanout.src.port_id = i;
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
fanout.tgt = *itr;
result.insert(fanout);
}
}
}
return result;
}
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
GraphView::GetFaninEdges(const NodeDef& node,
bool include_controlling_edges) const {
std::unordered_set<Edge, HashEdge> result;
for (int i = 0; i < node.input_size(); ++i) {
Edge fanin;
fanin.tgt.node = const_cast<NodeDef*>(&node);
fanin.tgt.port_id = i;
string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
if (fanin.src.port_id < 0) {
if (!include_controlling_edges) {
break;
}
}
auto it = nodes_.find(fanin_name);
if (it != nodes_.end()) {
fanin.src.node = it->second;
result.insert(fanin);
}
}
return result;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -18,9 +18,16 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -36,114 +43,290 @@ namespace grappler {
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
namespace internal {
// GraphViewInternal is a helper class to simplify graph traversal. It creates
// an immutable view of the nodes and edges represented by a GraphDef protocol
// buffer.
//
// There are two public classes implementing GraphViewInternal:
//
// - GraphView: constructed from the `const GraphDef` and doesn't allow
// to mutate underlying graph via input/output ports lookup functions (ports
// have const pointers to nodes).
//
// - MutableGraphView: constructed from the 'GraphDef` and allows to mutate
// the graph via input/output ports lookup functions (ports have non-const
// pointers to nodes), and also have couple additional functions to
// add/remove/replace nodes in the graph.
//
// --------------------------- !!! WARNING !!! ---------------------------------
// Removing nodes from the graph outside of MutableGraphView will
// lead to segfaults! Guaranteed by absl::string_view!
// -----------------------------------------------------------------------------
//
template <typename GraphDefT, typename NodeDefT>
class GraphViewInternal {
public:
struct Port {
Port() = default;
Port(NodeDef* n, int port) : node(n), port_id(port) {}
// TODO(prazek): ports should keep the constness of GraphView. The only way
// to modify graph through the view should be using MutableGraphView.
NodeDef* node = nullptr;
int port_id = -1;
Port() : node(nullptr), port_id(0) {}
Port(NodeDefT* n, int port) : node(n), port_id(port) {}
bool operator==(const Port& other) const {
return node == other.node && port_id == other.port_id;
}
};
struct InputPort : public Port {
InputPort() = default;
InputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
InputPort(const NodeDef* n, int port_id)
: Port(const_cast<NodeDef*>(n), port_id) {}
};
struct OutputPort : public Port {
OutputPort() = default;
OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
template <typename H>
friend H AbslHashValue(H h, const Port& p) {
return H::combine(std::move(h), p.node, p.port_id);
}
NodeDefT* node;
int port_id;
};
struct HashPort {
std::size_t operator()(const Port& port) const {
return reinterpret_cast<std::size_t>(port.node) + port.port_id;
}
struct InputPort : public Port {
using Port::Port;
};
struct OutputPort : public Port {
using Port::Port;
};
struct Edge {
OutputPort src;
InputPort tgt;
Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
bool operator==(const Edge& other) const {
return src == other.src && tgt == other.tgt;
return src == other.src && dst == other.dst;
}
};
struct HashEdge {
std::size_t operator()(const Edge& edge) const {
return HashPort()(edge.src) + HashPort()(edge.tgt);
template <typename H>
friend H AbslHashValue(H h, const Edge& e) {
return H::combine(std::move(h), e.src, e.dst);
}
OutputPort src;
InputPort dst;
};
explicit GraphView(GraphDef* graph);
GraphDef* GetGraph() const { return graph_; }
NodeDef* GetNode(const string& node_name) const;
GraphDefT* graph() const { return graph_; }
// Find a node by name or return `nullptr` if it's not in a graph view.
NodeDefT* GetNode(absl::string_view node_name) const {
return gtl::FindWithDefault(nodes_, node_name, nullptr);
}
// Get the specified input port. Note that the special '-1' port_id can be
// used to access the controlling nodes (i.e. the nodes connected to node_name
// through an incoming control dependency).
InputPort GetInputPort(const string& node_name, int port_id) const;
InputPort GetInputPort(absl::string_view node_name, int port_id) const {
return InputPort(GetNode(node_name), port_id);
}
// Get the specified output port. Note that the special '-1' port_id can be
// used to access the controlled nodes (i.e. the nodes connected to node_name
// through an outgoing control dependency).
OutputPort GetOutputPort(const string& node_name, int port_id) const;
OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
return OutputPort(GetNode(node_name), port_id);
}
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
// of an output (resp. input) port.
const std::unordered_set<InputPort, HashPort>& GetFanout(
const OutputPort& port) const;
std::unordered_set<OutputPort, HashPort> GetFanin(
const InputPort& port) const;
const absl::flat_hash_set<InputPort>& GetFanout(
const OutputPort& port) const {
return gtl::FindWithDefault(fanouts_, port, empty_set_);
}
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
if (port.port_id >= 0) return {GetRegularFanin(port)};
// Collect fanin for the control input.
absl::flat_hash_set<OutputPort> result;
for (int i = port.node->input_size() - 1; i >= 0; --i) {
TensorId tensor_id = ParseTensorName(port.node->input(i));
if (tensor_id.index() >= 0) break; // we reached regular inputs
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
}
return result;
}
// Special case: regular (i.e. non-control) input ports can only have one
// fanin.
const OutputPort GetRegularFanin(const InputPort& port) const;
const OutputPort GetRegularFanin(const InputPort& port) const {
DCHECK_GE(port.port_id, 0);
if (port.port_id < 0) return OutputPort();
// Get all the input (resp. output) ports in the immediate fanout (resp fanin)
// of a node. Include the controlling nodes iff include_controlling_nodes is
// true.
std::unordered_set<InputPort, HashPort> GetFanouts(
const NodeDef& node, bool include_controlled_nodes) const;
std::unordered_set<OutputPort, HashPort> GetFanins(
const NodeDef& node, bool include_controlling_nodes) const;
TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
return GetOutputPort(tensor_id.node(), tensor_id.index());
}
// Get all the input (resp. output) ports in the immediate fanout (resp
// fanin) of a node. Include the controlling nodes iff
// include_controlling_nodes is true.
absl::flat_hash_set<InputPort> GetFanouts(
const NodeDef& node, bool include_controlled_nodes) const {
absl::flat_hash_set<InputPort> result;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
result.insert(it->second.begin(), it->second.end());
}
}
return result;
}
absl::flat_hash_set<OutputPort> GetFanins(
const NodeDef& node, bool include_controlling_nodes) const {
absl::flat_hash_set<OutputPort> result;
for (int i = 0; i < node.input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node.input(i));
if (tensor_id.index() < 0 && !include_controlling_nodes) break;
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
}
return result;
}
// Get the number of ports in the immediate fanin of a node. Count the
// controlling nodes iff include_controlling_nodes is true.
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const;
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
int count = 0;
for (const string& input : node.input()) {
if (!include_controlling_nodes && IsControlInput(input)) {
break;
}
count += 1;
}
return count;
}
// Get all the edge in the immediate fanout (resp fanin) of a node. Include
// the control edges iff include_controlling_edges is true.
std::unordered_set<Edge, HashEdge> GetFanoutEdges(
const NodeDef& node, bool include_controlled_edges) const;
std::unordered_set<Edge, HashEdge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const;
// Get the number of ports in the immediate fanout of a node. Count the
// controlling nodes iff include_controlling_nodes is true.
int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
int count = 0;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlling_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) count += it->second.size();
}
return count;
}
// Get all the edges in the immediate fanout (resp fanin) of a node.
// Include the control edges iff include_controlling_edges is true.
absl::flat_hash_set<Edge> GetFanoutEdges(
const NodeDef& node, bool include_controlled_edges) const {
absl::flat_hash_set<Edge> result;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlled_edges ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
for (int i = first_port_id; i <= last_port_id; ++i) {
port.port_id = i;
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
/*dst*/ *itr);
}
}
}
return result;
}
absl::flat_hash_set<Edge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const {
absl::flat_hash_set<Edge> result;
for (int i = 0; i < node.input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node.input(i));
if (tensor_id.index() < 0 && !include_controlling_edges) break;
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) {
result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
/*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
}
}
return result;
}
protected:
// Add a new `node` to the graph.
void AddUniqueNodeOrDie(NodeDef* node);
// Add fanout to every `node` input.
void AddFanouts(NodeDef* node);
std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
GraphDef* MutableGraph() { return graph_; }
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
using FanoutsMapType =
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
HashPort>;
FanoutsMapType* MutableFanouts() { return &fanouts_; }
void AddUniqueNodeOrDie(NodeDefT* node) {
auto result = nodes_.emplace(node->name(), node);
// TODO(ezhulenev): Replace CHECK with factory method returning
// absl::StatusOr (when available).
CHECK(result.second) << "Non unique node name detected: " << node->name();
}
void AddFanouts(NodeDefT* node) {
for (int i = 0; i < node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node->input(i));
OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
if (output.port_id < 0) {
fanouts_[output].emplace(node, -1);
} else {
num_regular_outputs_[output.node] =
std::max(num_regular_outputs_[output.node], output.port_id);
fanouts_[output].emplace(node, i);
}
}
}
// Access to the mutable internal state for MutableGraphView.
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
return &nodes_;
}
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
mutable_fanouts() {
return &fanouts_;
}
private:
GraphDef* graph_;
std::unordered_map<string, NodeDef*> nodes_;
std::unordered_set<InputPort, HashPort> empty_set_;
FanoutsMapType fanouts_;
std::unordered_map<const NodeDef*, int> num_regular_outputs_;
GraphDefT* graph_; // must outlive the graph view
absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
absl::flat_hash_set<InputPort> empty_set_;
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
std::unordered_map<NodeDefT*, int> num_regular_outputs_;
};
} // namespace internal
// Immutable GraphView that keeps the constness of the GraphDef. If you need to
// mutate the graph or the nodes via the graph view lookup functions, see
// MutableGraphView.
class GraphView
: public internal::GraphViewInternal<const GraphDef, const NodeDef> {
public:
explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
for (const NodeDef& node : graph->node()) AddFanouts(&node);
}
};
} // end namespace grappler

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) {
const NodeDef* add_node = graph.GetNode("AddN");
EXPECT_NE(nullptr, add_node);
string fanouts;
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
strings::StrAppend(&fanouts,
strings::StrCat(fo.node->name(), ":", fo.port_id, " "));
}
EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts);
string fanins;
for (const auto& fi : graph.GetFanins(*add_node, false)) {
strings::StrAppend(&fanins,
strings::StrCat(fi.node->name(), ":", fi.port_id, " "));
absl::flat_hash_set<string> fanouts;
absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
}
EXPECT_EQ("Square_1:0 Square:0 ", fanins);
EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
EXPECT_EQ(fanouts, expected_fanouts);
absl::flat_hash_set<string> fanins;
absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
for (const auto& fi : graph.GetFanins(*add_node, false)) {
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
}
EXPECT_EQ(graph.NumFanins(*add_node, false), 2);
EXPECT_EQ(fanins, expected_fanins);
}
TEST_F(GraphViewTest, ControlDependencies) {

View File

@ -19,8 +19,26 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
const absl::flat_hash_set<MutableGraphView::InputPort>&
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
port.port_id));
}
absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
const GraphView::InputPort& port) const {
return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
port.port_id));
}
const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
const GraphView::InputPort& port) const {
return GetRegularFanin(MutableGraphView::InputPort(
const_cast<NodeDef*>(port.node), port.port_id));
}
NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
auto* node_in_graph = GetGraph()->add_node();
auto* node_in_graph = graph()->add_node();
*node_in_graph = std::move(node);
AddUniqueNodeOrDie(node_in_graph);
@ -31,7 +49,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
const int output_port_id) {
auto* node_in_graph = GetGraph()->add_node();
auto* node_in_graph = graph()->add_node();
*node_in_graph = std::move(node);
AddUniqueNodeOrDie(node_in_graph);
@ -46,8 +64,7 @@ NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
const NodeDef& new_input,
const int output_port_id) {
GraphView::OutputPort output_port =
GetOutputPort(old_input.name(), output_port_id);
OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
auto fanout = GetFanout(output_port);
for (auto& input_port : fanout) {
input_port.node->set_input(input_port.port_id, new_input.name());
@ -57,17 +74,17 @@ void MutableGraphView::ReplaceInput(const NodeDef& old_input,
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
for (const string& node_name_to_delete : nodes_to_delete)
RemoveFanouts(MutableNodes()->at(node_name_to_delete));
RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
for (const string& node_name_to_delete : nodes_to_delete)
MutableNodes()->erase(node_name_to_delete);
EraseNodesFromGraph(nodes_to_delete, GetGraph());
mutable_nodes()->erase(node_name_to_delete);
EraseNodesFromGraph(nodes_to_delete, graph());
}
void MutableGraphView::RemoveFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = (*MutableNodes())[fanin_name];
fanin.node = (*mutable_nodes())[fanin_name];
InputPort input;
input.node = node;
@ -76,7 +93,7 @@ void MutableGraphView::RemoveFanouts(NodeDef* node) {
else
input.port_id = i;
(*MutableFanouts())[fanin].erase(input);
(*mutable_fanouts())[fanin].erase(input);
}
}

View File

@ -24,11 +24,25 @@ namespace grappler {
// A utility class to simplify the traversal of a GraphDef that, unlike
// GraphView, supports updating the graph. Note that you should not modify the
// graph separately, because the view will get out of sync.
class MutableGraphView : public GraphView {
public:
using GraphView::GraphView;
GraphDef* GetGraph() { return MutableGraph(); }
class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
public:
explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node);
}
// Lookup fanouts/fanins using immutable ports.
using GraphViewInternal::GetFanout;
const absl::flat_hash_set<InputPort>& GetFanout(
const GraphView::OutputPort& port) const;
using GraphViewInternal::GetFanin;
absl::flat_hash_set<OutputPort> GetFanin(
const GraphView::InputPort& port) const;
using GraphViewInternal::GetRegularFanin;
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
// Adds a new node to graph and updates the view.
NodeDef* AddNode(NodeDef&& node);

View File

@ -26,7 +26,8 @@ namespace {
bool FindChildWithName(const MutableGraphView& graph,
const string& output_port_name,
const string& input_name) {
GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(output_port_name, 0);
auto fanout = graph.GetFanout(output_port);
for (auto& input_port : fanout) {
if (input_port.node->name() == input_name) return true;
@ -59,10 +60,10 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph);
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
EXPECT_EQ("AddN", input.node->name());
EXPECT_EQ(0, input.port_id);
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
EXPECT_EQ("Square", fanin.node->name());
EXPECT_EQ(0, fanin.port_id);
@ -89,7 +90,7 @@ TEST(MutableGraphViewTest, InsertNodes) {
GraphDef new_graph = item.graph;
MutableGraphView graph(&new_graph);
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
NodeDef new_node = *input.node;
new_node.set_name("new_node");

View File

@ -145,8 +145,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions",
@ -422,8 +422,8 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
@ -625,12 +625,13 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
"@com_google_absl//absl/container:flat_hash_set",
],
)
@ -663,8 +664,8 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",

View File

@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
const FunctionDef& fused_function,
MutableGraphView* graph) {
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(),
&fused_node);
fused_node.set_op("FilterDataset");

View File

@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper(
MutableGraphView* graph) {
NodeDef node;
node.set_op(kConstOpName);
SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node);
SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
(*node.mutable_attr())["dtype"].set_type(dtype);
std::unique_ptr<tensorflow::TensorProto> tensor =
@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper(
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
NodeDef node;
node.set_op("Placeholder");
SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
(*node.mutable_attr())["dtype"].set_type(dtype);
TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
shape->set_unknown_rank(false);
@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
if (!name.empty()) {
node.set_name(string(name));
} else {
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
SetUniqueGraphNodeName(op, graph->graph(), &node);
}
node.set_op(string(op));
for (const string& input : inputs) {
@ -228,7 +228,7 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
if (node.input_size() == 0) return nullptr;
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
return graph.GetRegularFanin(input_port).node;
}

View File

@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph()));
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
}
@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
EXPECT_TRUE(
ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
}
@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph()));
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
}
@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph()));
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
}
@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph()));
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
}
@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
EXPECT_TRUE(
ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph()));
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
}
@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) {
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
AddNode("A", "OpA", {}, {}, &graph);
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph()));
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));
graph.DeleteNodes({"A"});
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
}
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
AddNode("A", "OpA", {}, {}, &graph);
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph()));
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));
graph.DeleteNodes({"A"});
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
}
TEST(GraphUtilsTest, FindGraphNodeWithName) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1);
graph.DeleteNodes({"A"});
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
}
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);
graph.DeleteNodes({"B"});
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
}
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph);
std::vector<int> result_indices =
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
FindAllGraphNodesWithOp("OpA", *graph.graph());
EXPECT_EQ(result_indices.size(), 2);
EXPECT_EQ(result_indices.at(0), 0);
EXPECT_EQ(result_indices.at(1), 2);
graph.DeleteNodes({"A2"});
std::vector<int> result_indices_new =
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
FindAllGraphNodesWithOp("OpA", *graph.graph());
EXPECT_EQ(result_indices_new.size(), 1);
EXPECT_EQ(result_indices_new.at(0), 0);
}

View File

@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
const FunctionDef& stateless_function,
MutableGraphView* graph) {
NodeDef stateless_map;
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(),
&stateless_map);
stateless_map.set_op("MapDataset");
@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
MutableGraphView* graph) {
NodeDef random_dataset;
random_dataset.set_op("RandomDataset");
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(),
&random_dataset);
const auto* seed = graph_utils::AddScalarConstNode<int64>(
@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
NodeDef batch_dataset;
batch_dataset.set_op("BatchDatasetV2");
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(),
&batch_dataset);
const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
MutableGraphView* graph) {
NodeDef zip_node;
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(),
&zip_node);
zip_node.set_op("ZipDataset");

View File

@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kInsertOpName);
graph_utils::SetUniqueGraphNodeName(
strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(),
&new_node);
strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node);
// Set the input of LatencyDataset node as `node`
new_node.add_input(node.name());
@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
// node corresponds to a `Dataset` op.
continue;
}
GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0);
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(node.name(), 0);
auto fanout = graph.GetFanout(output_port);
if (fanout.size() > 1) {
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();

View File

@ -29,7 +29,7 @@ namespace {
NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef numa_aware_node = node;
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(),
&numa_aware_node);
numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
return numa_aware_node;

View File

@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
&new_node);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node);
// Set the `input` input argument.
new_node.add_input(map_node.input(0));

View File

@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) {
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
}
} // namespace

View File

@ -37,8 +37,7 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
const FunctionDef& fused_function,
MutableGraphView* graph) {
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
@ -72,8 +71,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node,
const NodeDef& filter_node,
MutableGraphView* graph) {
NodeDef filter_by_component;
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent",
graph->GetGraph(), &filter_by_component);
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(),
&filter_by_component);
filter_by_component.set_op("FilterByLastComponentDataset");
filter_by_component.add_input(fused_map_node.name());

View File

@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
const FunctionDef& fused_function,
MutableGraphView* graph) {
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));

View File

@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function,
NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
NodeDef parallel_map = map_node;
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(),
&parallel_map);
parallel_map.set_op("ParallelMapDataset");
// TODO(b/114475558): We want to set `num_parallel_calls` to a special value,

View File

@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
MutableGraphView* graph) {
NodeDef batch_node;
batch_node.set_op(old_batch_node.op());
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(),
&batch_node);
// Set the `input_dataset` input argument
@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
MutableGraphView* graph) {
NodeDef map_node;
map_node.set_op(old_map_node.op());
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
&map_node);
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node);
// Set the `input_dataset` input argument
map_node.add_input(new_batch_node.name());

View File

@ -30,7 +30,7 @@ namespace tensorflow {
namespace grappler {
namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
const auto& count_node = *graph.GetNode(take_node.input(1));
@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) {
return node.attr().at("value").tensor().int64_val(0) == value;
}
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
// We are looking only for skip(0) nodes.
return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
// We are looking only for repeat(1) nodes.
return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) {
bool IsPrefetchZero(const NodeDef& prefetch_node,
const MutableGraphView& graph) {
if (prefetch_node.op() != "PrefetchDataset") return false;
// We are looking only for prefetch(0) nodes.
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
}

View File

@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
GraphDef output;
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
}
} // namespace

View File

@ -31,8 +31,8 @@ limitations under the License.
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
@ -219,8 +219,7 @@ class FunctionOptimizerContext {
: grappler_item_id_(item.id),
graph_version_(item.graph.versions().producer()),
function_library_(OpRegistry::Global(), item.graph.library()),
// GraphView doesn't not modify the graph or the nodes.
graph_view_(const_cast<GraphDef*>(&item.graph)) {
graph_view_(&item.graph) {
InitializeTrulyConstNodes(item);
InitializeInlinedFunctions(opt_level, item);
InitializeFetchNodes(item);
@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Function specialization might change the number of function outputs, so we
// have to process the final optimized graph and update all the node mapping.
if (ctx.RequiresOutputMapping()) {
GraphView optimized_graph_view(optimized_graph);
MutableGraphView optimized_graph_view(optimized_graph);
for (const auto& output_mapping : ctx.output_mappings()) {
const auto& node_name = output_mapping.first;
const auto& mappings = output_mapping.second;
@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
int to = mapping.second;
// Get the output port corresponding to the old output position.
GraphView::OutputPort from_port =
MutableGraphView::OutputPort from_port =
optimized_graph_view.GetOutputPort(node_name, from);
// Update all input ports that read from old output port.
for (GraphView::InputPort to_port :
for (MutableGraphView::InputPort to_port :
optimized_graph_view.GetFanout(from_port)) {
*to_port.node->mutable_input(to_port.port_id) =
strings::StrCat(node_name, ":", to);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@ -29,8 +30,8 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
return Status::OK();
}
Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
const NodeMap& node_map, DeviceBase* cpu_device,
ResourceMgr* resource_mgr, bool* has_dead_fanout,
int* dead_fanout) {
Status CheckForDeadFanout(const MutableGraphView& view,
const NodeDef& switch_node, const NodeMap& node_map,
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
bool* has_dead_fanout, int* dead_fanout) {
*has_dead_fanout = false;
GraphView::InputPort switch_loopcond_port(&switch_node, 1);
NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node;
const NodeDef* switch_predicate =
view.GetRegularFanin(switch_loopcond_port).node;
// CASE 1: Control is a constant.
if (IsConstant(*switch_predicate)) {
@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
}
GraphView::InputPort switch_input_port(&switch_node, 0);
NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
// CASE 2: Zero-iteration while loop.
// We check if its a while loop such that the condition is a simple binary
@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches(
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
// TODO(bsteiner): also rewrite switches as identity. For now we just record
// them
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
identity_switches;
absl::flat_hash_set<GraphView::OutputPort> identity_switches;
GraphView view(optimized_graph);
MutableGraphView view(optimized_graph);
for (const NodeDef& node : optimized_graph->node()) {
if (!IsSwitch(node)) {
continue;
@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches(
if (!has_dead_fanout) {
continue;
}
GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
GraphView::OutputPort dead(&node, dead_fanout);
identity_switches.insert(dead);
SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
for (const GraphView::InputPort& port : view.GetFanout(dead)) {
SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
zombie_inputs;
for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
if (dead_nodes.find(port.node) == dead_nodes.end()) {
zombie_inputs.PushBack(port);
}
@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches(
dead_merge_inputs;
bool found_node_to_preserve = false;
while (!found_node_to_preserve && !zombie_inputs.Empty()) {
GraphView::InputPort dead = zombie_inputs.PopBack();
MutableGraphView::InputPort dead = zombie_inputs.PopBack();
if (nodes_to_preserve.find(dead.node->name()) !=
nodes_to_preserve.end()) {
found_node_to_preserve = true;
@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches(
found_node_to_preserve = true;
break;
}
GraphView::OutputPort value_index(dead.node, 1);
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
index_fanout = view.GetFanout(value_index);
MutableGraphView::OutputPort value_index(dead.node, 1);
const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
view.GetFanout(value_index);
if (!index_fanout.empty()) {
// The 2nd output (that indicates which input is propagated) is
// connected. This never happens in practice, so we'll just skip this
@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches(
}
if (fully_dead) {
local_dead_nodes.insert(dead.node);
for (const GraphView::InputPort& port :
for (const MutableGraphView::InputPort& port :
view.GetFanouts(*dead.node, true)) {
zombie_inputs.PushBack(port);
}
@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches(
break;
} else {
if (local_dead_nodes.insert(dead.node).second) {
for (const GraphView::InputPort& dead_fanout :
for (const MutableGraphView::InputPort& dead_fanout :
view.GetFanouts(*dead.node, true)) {
zombie_inputs.PushBack(dead_fanout);
}

View File

@ -30,8 +30,8 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_memory.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
// Look for AddN nodes (and equivalent) and record input names.
GraphView view(&item->graph);
MutableGraphView view(&item->graph);
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
for (NodeDef& node : *item->graph.mutable_node()) {
@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
const string node_name = NodeName(input);
NodeDef* node = view.GetNode(node_name);
const NodeDef* node = view.GetNode(node_name);
input_topo_index.push_back(topo_order.at(node));
}
int min_input_topo_index = INT_MAX;
@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger(
return nullptr;
}
static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
static bool IsSwappable(const MutableGraphView& graph,
MutableGraphView::OutputPort output) {
const NodeDef& node = *output.node;
// There is no point in swapping out persistent tensors, since the tensor will
// continue to use memory.
@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
// If placed on the same device, these nodes are just forwarding references
// to their input. Therefore they are swappable iff their fanin is swappable
// or it resides on a different device.
GraphView::InputPort input;
MutableGraphView::InputPort input;
input.node = output.node;
input.port_id = 0;
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
if (fanin.node->device() == node.device()) {
return IsSwappable(graph, fanin);
}
@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
}
static NodeDef* FindSwapOutTrigger(
const NodeDef* node, int input_id, const GraphView& view,
const NodeDef* node, int input_id, const MutableGraphView& view,
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
execution_times) {
// Find the output port that generated the tensor to swap.
GraphView::InputPort swap;
MutableGraphView::InputPort swap;
swap.node = const_cast<NodeDef*>(node);
swap.port_id = input_id;
GraphView::OutputPort generator = view.GetRegularFanin(swap);
MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
if (!generator.node) {
return nullptr;
}
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
view.GetFanout(generator);
NodeDef* trigger = nullptr;
Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger(
return trigger;
}
static bool IsSwappable(GraphView::InputPort input) {
static bool IsSwappable(MutableGraphView::InputPort input) {
const NodeDef& node = *input.node;
const OpDef* op_def;
@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) {
}
struct MemInfo {
GraphView::OutputPort port;
MutableGraphView::OutputPort port;
int64 memory_used;
std::vector<GraphView::InputPort> uses_left;
std::vector<MutableGraphView::InputPort> uses_left;
double fitness;
bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates(
std::vector<MemInfo> mem_state;
GraphView graph(&item->graph);
MutableGraphView graph(&item->graph);
for (const auto& live_tensor : mem_usage.live_tensors) {
if (live_tensor.memory_used <= 1024) {
// Don't bother with small tensors.
@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates(
if (skip_list->find(live_tensor.node) != skip_list->end()) {
continue;
}
GraphView::OutputPort port =
MutableGraphView::OutputPort port =
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
if (!IsSwappable(graph, port)) {
continue;
@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates(
Costs::Duration allocation_time = live_tensor.allocation_time;
Costs::Duration earliest_use(Costs::Duration::infinity());
bool valid = true;
for (GraphView::InputPort input : graph.GetFanout(port)) {
for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
// Get execution time.
auto it = op_completion_times.find(input.node->name());
if (it == op_completion_times.end()) {
@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates(
// the values do not fit into any integral type.
mem_info.fitness =
MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info);
@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates(
std::sort(mem_state.begin(), mem_state.end());
for (const MemInfo& mem_info : mem_state) {
for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
for (const MutableGraphView::InputPort fanout_to_swap :
mem_info.uses_left) {
VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
<< fanout_to_swap.port_id << " of tensor "
<< mem_info.port.node->name() << ":" << mem_info.port.port_id
@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
for (const auto& node : item->graph.node()) {
name_map[node.name()] = &node;
}
GraphView view(&item->graph);
MutableGraphView view(&item->graph);
bool updated_graph = false;

View File

@ -18,8 +18,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphProperties properties(item);
bool inferred_properties = false;
GraphView graph(optimized_graph);
MutableGraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
// simply as the size of the tensor.
@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsShape(node)) {
continue;
}
for (GraphView::InputPort fanout :
graph.GetFanout(GraphView::OutputPort(&node, 0))) {
for (MutableGraphView::InputPort fanout :
graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
if (fanout.node->op() != "Prod") {
continue;
}
@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// rewrite the whole expression directly as a Size operation.
continue;
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
const MutableGraphView::OutputPort reduce_indices =
graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
if (!inferred_properties) {
// Infer properties lazily in case they are not needed.
TF_RETURN_IF_ERROR(properties.InferStatically(false));
@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// is possible whenever the symbolic dimensions in the numerator and
// denominator cancel each other.
if (node.op() == "Div") {
const GraphView::OutputPort input1 =
graph.GetRegularFanin(GraphView::InputPort(&node, 0));
const GraphView::OutputPort input2 =
graph.GetRegularFanin(GraphView::InputPort(&node, 1));
const MutableGraphView::OutputPort input1 =
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
const MutableGraphView::OutputPort input2 =
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}

View File

@ -101,6 +101,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:mutable_graph_view",
"@com_google_absl//absl/container:flat_hash_map",
],
)

View File

@ -21,8 +21,11 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
void ReverseDfs(
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
namespace {
template <typename GraphViewType>
void ReverseDfsInternal(
const GraphViewType& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
@ -79,5 +82,25 @@ void ReverseDfs(
}
}
} // namespace
void ReverseDfs(
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
ReverseDfsInternal<GraphView>(graph_view, from, pre_order, post_order,
on_back_edge);
}
void ReverseDfs(
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
ReverseDfsInternal<MutableGraphView>(graph_view, from, pre_order, post_order,
on_back_edge);
}
} // namespace grappler
} // namespace tensorflow

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <functional>
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
namespace tensorflow {
namespace grappler {
@ -34,6 +35,12 @@ void ReverseDfs(
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
void ReverseDfs(
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
const std::function<void(const NodeDef*)>& pre_order,
const std::function<void(const NodeDef*)>& post_order,
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
} // namespace grappler
} // namespace tensorflow

View File

@ -14,9 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/traversal.h"
//#include "tensorflow/core/framework/node_def.pb.h"
//#include "tensorflow/core/lib/core/status_test_util.h"
//#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) {
found_back_edge = true;
});
EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order);
EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order);
// Pre/Post order traversals are non deterministic because a node fanin is an
// absl::flat_hash_set with non deterministic traversal order.
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
std::set<ValidTraversal> valid_traversals = {
// pre_order post_order
{{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}},
{{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}};
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
EXPECT_FALSE(found_back_edge);
}
@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) {
back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
});
EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order);
EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order);
// Pre/Post order traversals are non deterministic because a node fanin is an
// absl::flat_hash_set with non deterministic traversal order.
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
std::set<ValidTraversal> valid_traversals = {
// pre_order post_order
{{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}},
{{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}},
{{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}};
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
}