[Grappler] Preserve constness of the GraphDef in GraphView.
1. Split GrapView into GraphView and MutableGraphView with separate {Input/Output}Port types with different node pointer constness. 2. Properly use GraphView and MutableGraphView in graph properties, and get rid of const_cast. 3. Remove const_cast in function optimizer. 4. Migrate GraphView to absl containers and hash PiperOrigin-RevId: 219488040
This commit is contained in:
parent
92e604060a
commit
3eeaf9f1e1
@ -69,6 +69,9 @@ cc_library(
|
||||
":utils",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/hash",
|
||||
],
|
||||
)
|
||||
|
||||
@ -82,6 +85,8 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -44,7 +44,7 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -30,7 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
@ -456,10 +456,10 @@ class SymbolicShapeRefiner {
|
||||
const GraphView& graph,
|
||||
const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
|
||||
: graph_(graph),
|
||||
function_library_(OpRegistry::Global(), graph.GetGraph()->library()),
|
||||
function_library_(OpRegistry::Global(), graph.graph()->library()),
|
||||
fed_ports_(fed_ports) {
|
||||
graph_def_version_ = graph.GetGraph()->versions().producer();
|
||||
node_to_context_.reserve(graph.GetGraph()->node_size());
|
||||
graph_def_version_ = graph.graph()->versions().producer();
|
||||
node_to_context_.reserve(graph.graph()->node_size());
|
||||
}
|
||||
|
||||
const GraphView& graph() const { return graph_; }
|
||||
@ -512,7 +512,7 @@ class SymbolicShapeRefiner {
|
||||
// Placeholder with Const) don't affect one in
|
||||
// fun_to_grappler_function_item_.
|
||||
GrapplerFunctionItem grappler_function_item = it->second;
|
||||
GraphView gv(&grappler_function_item.graph);
|
||||
MutableGraphView gv(&grappler_function_item.graph);
|
||||
|
||||
// Forward shapes from function input nodes to argument nodes.
|
||||
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
|
||||
@ -532,7 +532,7 @@ class SymbolicShapeRefiner {
|
||||
"Function inputs should not contain control nodes.");
|
||||
}
|
||||
|
||||
NodeDef* input_node = graph_.GetNode(node_name);
|
||||
const NodeDef* input_node = graph_.GetNode(node_name);
|
||||
if (input_node == nullptr) {
|
||||
return errors::FailedPrecondition(node_name,
|
||||
" was not found in the graph.");
|
||||
@ -566,7 +566,7 @@ class SymbolicShapeRefiner {
|
||||
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
|
||||
const string& input = function_node->input(i);
|
||||
const string& node_name = NodeName(input);
|
||||
NodeDef* input_node = graph_.GetNode(node_name);
|
||||
const NodeDef* input_node = graph_.GetNode(node_name);
|
||||
if (IsConstant(*input_node)) {
|
||||
TF_CHECK_OK(
|
||||
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
|
||||
@ -1441,8 +1441,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
|
||||
continue;
|
||||
}
|
||||
ShapeHandle input = in->output(fanin.src.port_id);
|
||||
CHECK_EQ(fanin.tgt.node, node);
|
||||
c->SetInput(fanin.tgt.port_id, input);
|
||||
CHECK_EQ(fanin.dst.node, node);
|
||||
c->SetInput(fanin.dst.port_id, input);
|
||||
if (!out_initialized) {
|
||||
out_initialized = true;
|
||||
out = input;
|
||||
@ -1673,7 +1673,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
|
||||
}
|
||||
}
|
||||
|
||||
GraphView graph_view(const_cast<GraphDef*>(&item_.graph));
|
||||
GraphView graph_view(&item_.graph);
|
||||
|
||||
// List the resources and the nodes using them. Also collect the Merge nodes,
|
||||
// fed nodes, and primary inputs.
|
||||
@ -1725,10 +1725,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
|
||||
for (const auto& resource : resources) {
|
||||
for (const NodeDef* src : resource.second.first) {
|
||||
resource_handles[src] = resource.first;
|
||||
for (const NodeDef* tgt : resource.second.second) {
|
||||
for (const NodeDef* dst : resource.second.second) {
|
||||
// Add control edges from enqueue to dequeue nodes to ensure they are
|
||||
// processed in their logical order.
|
||||
extra_deps.emplace_back(src, tgt);
|
||||
extra_deps.emplace_back(src, dst);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
|
||||
return OpPortIdToArgId(node, op.input_arg(), port_id);
|
||||
}
|
||||
|
||||
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
|
||||
for (int i = 0; i < graph_->node_size(); i++) {
|
||||
auto node = graph_->mutable_node(i);
|
||||
AddUniqueNodeOrDie(node);
|
||||
}
|
||||
|
||||
for (NodeDef& node : *graph_->mutable_node()) {
|
||||
AddFanouts(&node);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
|
||||
auto result = nodes_.emplace(node->name(), node);
|
||||
// Check that the graph doesn't contain multiple nodes with the same name.
|
||||
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
||||
}
|
||||
|
||||
void GraphView::AddFanouts(NodeDef* node) {
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
OutputPort fanin;
|
||||
const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
|
||||
fanin.node = nodes_[fanin_name];
|
||||
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
if (fanin.port_id < 0) {
|
||||
input.port_id = -1;
|
||||
} else {
|
||||
input.port_id = i;
|
||||
num_regular_outputs_[fanin.node] =
|
||||
std::max(num_regular_outputs_[fanin.node], fanin.port_id);
|
||||
}
|
||||
|
||||
fanouts_[fanin].insert(input);
|
||||
}
|
||||
}
|
||||
|
||||
NodeDef* GraphView::GetNode(const string& node_name) const {
|
||||
auto it = nodes_.find(node_name);
|
||||
if (it == nodes_.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
|
||||
int port_id) const {
|
||||
InputPort result;
|
||||
result.node = GetNode(node_name);
|
||||
// TODO(bsteiner): verify that the node has at least port_id input ports
|
||||
result.port_id = port_id;
|
||||
return result;
|
||||
}
|
||||
|
||||
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
|
||||
int port_id) const {
|
||||
OutputPort result;
|
||||
result.node = GetNode(node_name);
|
||||
// TODO(bsteiner): verify that the node has at least port_id output ports
|
||||
result.port_id = port_id;
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
|
||||
GraphView::GetFanout(const GraphView::OutputPort& port) const {
|
||||
auto it = fanouts_.find(port);
|
||||
if (it == fanouts_.end()) {
|
||||
return empty_set_;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
||||
GraphView::GetFanin(const GraphView::InputPort& port) const {
|
||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort> result;
|
||||
if (port.port_id >= 0) {
|
||||
result.insert(GetRegularFanin(port));
|
||||
} else {
|
||||
for (int i = port.node->input_size() - 1; i >= 0; --i) {
|
||||
OutputPort fanin;
|
||||
string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id);
|
||||
if (fanin.port_id < 0) {
|
||||
auto it = nodes_.find(fanin_name);
|
||||
if (it != nodes_.end()) {
|
||||
fanin.node = it->second;
|
||||
result.insert(fanin);
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
const GraphView::OutputPort GraphView::GetRegularFanin(
|
||||
const GraphView::InputPort& port) const {
|
||||
CHECK_LE(0, port.port_id);
|
||||
OutputPort fanin;
|
||||
string fanin_name =
|
||||
ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
|
||||
auto it = nodes_.find(fanin_name);
|
||||
if (it == nodes_.end()) {
|
||||
fanin.node = nullptr;
|
||||
} else {
|
||||
fanin.node = it->second;
|
||||
}
|
||||
return fanin;
|
||||
}
|
||||
|
||||
std::unordered_set<GraphView::InputPort, GraphView::HashPort>
|
||||
GraphView::GetFanouts(const NodeDef& node,
|
||||
bool include_controlled_nodes) const {
|
||||
std::unordered_set<InputPort, HashPort> result;
|
||||
OutputPort port;
|
||||
port.node = const_cast<NodeDef*>(&node);
|
||||
const int first_port_id = include_controlled_nodes ? -1 : 0;
|
||||
auto it = num_regular_outputs_.find(&node);
|
||||
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
auto it = fanouts_.find(port);
|
||||
if (it != fanouts_.end()) {
|
||||
result.insert(it->second.begin(), it->second.end());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
||||
GraphView::GetFanins(const NodeDef& node,
|
||||
bool include_controlling_nodes) const {
|
||||
std::unordered_set<OutputPort, HashPort> result;
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
OutputPort fanin;
|
||||
string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
|
||||
if (fanin.port_id < 0) {
|
||||
if (!include_controlling_nodes) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto it = nodes_.find(fanin_name);
|
||||
if (it != nodes_.end()) {
|
||||
fanin.node = it->second;
|
||||
result.insert(fanin);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
int GraphView::NumFanins(const NodeDef& node,
|
||||
bool include_controlling_nodes) const {
|
||||
int count = 0;
|
||||
for (const string& input : node.input()) {
|
||||
if (!include_controlling_nodes && IsControlInput(input)) {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
|
||||
GraphView::GetFanoutEdges(const NodeDef& node,
|
||||
bool include_controlled_edges) const {
|
||||
std::unordered_set<Edge, HashEdge> result;
|
||||
OutputPort port;
|
||||
port.node = const_cast<NodeDef*>(&node);
|
||||
const int first_port_id = include_controlled_edges ? -1 : 0;
|
||||
auto it = num_regular_outputs_.find(&node);
|
||||
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
auto it = fanouts_.find(port);
|
||||
if (it != fanouts_.end()) {
|
||||
Edge fanout;
|
||||
fanout.src.node = const_cast<NodeDef*>(&node);
|
||||
fanout.src.port_id = i;
|
||||
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
|
||||
fanout.tgt = *itr;
|
||||
result.insert(fanout);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
|
||||
GraphView::GetFaninEdges(const NodeDef& node,
|
||||
bool include_controlling_edges) const {
|
||||
std::unordered_set<Edge, HashEdge> result;
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
Edge fanin;
|
||||
fanin.tgt.node = const_cast<NodeDef*>(&node);
|
||||
fanin.tgt.port_id = i;
|
||||
string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
|
||||
if (fanin.src.port_id < 0) {
|
||||
if (!include_controlling_edges) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
auto it = nodes_.find(fanin_name);
|
||||
if (it != nodes_.end()) {
|
||||
fanin.src.node = it->second;
|
||||
result.insert(fanin);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -18,9 +18,16 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/hash/hash.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -36,114 +43,290 @@ namespace grappler {
|
||||
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
||||
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
||||
|
||||
// A utility class to simplify the traversal of a GraphDef.
|
||||
class GraphView {
|
||||
namespace internal {
|
||||
|
||||
// GraphViewInternal is a helper class to simplify graph traversal. It creates
|
||||
// an immutable view of the nodes and edges represented by a GraphDef protocol
|
||||
// buffer.
|
||||
//
|
||||
// There are two public classes implementing GraphViewInternal:
|
||||
//
|
||||
// - GraphView: constructed from the `const GraphDef` and doesn't allow
|
||||
// to mutate underlying graph via input/output ports lookup functions (ports
|
||||
// have const pointers to nodes).
|
||||
//
|
||||
// - MutableGraphView: constructed from the 'GraphDef` and allows to mutate
|
||||
// the graph via input/output ports lookup functions (ports have non-const
|
||||
// pointers to nodes), and also have couple additional functions to
|
||||
// add/remove/replace nodes in the graph.
|
||||
//
|
||||
// --------------------------- !!! WARNING !!! ---------------------------------
|
||||
// Removing nodes from the graph outside of MutableGraphView will
|
||||
// lead to segfaults! Guaranteed by absl::string_view!
|
||||
// -----------------------------------------------------------------------------
|
||||
//
|
||||
template <typename GraphDefT, typename NodeDefT>
|
||||
class GraphViewInternal {
|
||||
public:
|
||||
struct Port {
|
||||
Port() = default;
|
||||
Port(NodeDef* n, int port) : node(n), port_id(port) {}
|
||||
|
||||
// TODO(prazek): ports should keep the constness of GraphView. The only way
|
||||
// to modify graph through the view should be using MutableGraphView.
|
||||
NodeDef* node = nullptr;
|
||||
int port_id = -1;
|
||||
Port() : node(nullptr), port_id(0) {}
|
||||
Port(NodeDefT* n, int port) : node(n), port_id(port) {}
|
||||
|
||||
bool operator==(const Port& other) const {
|
||||
return node == other.node && port_id == other.port_id;
|
||||
}
|
||||
};
|
||||
struct InputPort : public Port {
|
||||
InputPort() = default;
|
||||
InputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
|
||||
InputPort(const NodeDef* n, int port_id)
|
||||
: Port(const_cast<NodeDef*>(n), port_id) {}
|
||||
};
|
||||
struct OutputPort : public Port {
|
||||
OutputPort() = default;
|
||||
OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Port& p) {
|
||||
return H::combine(std::move(h), p.node, p.port_id);
|
||||
}
|
||||
|
||||
NodeDefT* node;
|
||||
int port_id;
|
||||
};
|
||||
|
||||
struct HashPort {
|
||||
std::size_t operator()(const Port& port) const {
|
||||
return reinterpret_cast<std::size_t>(port.node) + port.port_id;
|
||||
}
|
||||
struct InputPort : public Port {
|
||||
using Port::Port;
|
||||
};
|
||||
|
||||
struct OutputPort : public Port {
|
||||
using Port::Port;
|
||||
};
|
||||
|
||||
struct Edge {
|
||||
OutputPort src;
|
||||
InputPort tgt;
|
||||
Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
|
||||
|
||||
bool operator==(const Edge& other) const {
|
||||
return src == other.src && tgt == other.tgt;
|
||||
return src == other.src && dst == other.dst;
|
||||
}
|
||||
};
|
||||
struct HashEdge {
|
||||
std::size_t operator()(const Edge& edge) const {
|
||||
return HashPort()(edge.src) + HashPort()(edge.tgt);
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const Edge& e) {
|
||||
return H::combine(std::move(h), e.src, e.dst);
|
||||
}
|
||||
|
||||
OutputPort src;
|
||||
InputPort dst;
|
||||
};
|
||||
|
||||
explicit GraphView(GraphDef* graph);
|
||||
GraphDef* GetGraph() const { return graph_; }
|
||||
NodeDef* GetNode(const string& node_name) const;
|
||||
GraphDefT* graph() const { return graph_; }
|
||||
|
||||
// Find a node by name or return `nullptr` if it's not in a graph view.
|
||||
NodeDefT* GetNode(absl::string_view node_name) const {
|
||||
return gtl::FindWithDefault(nodes_, node_name, nullptr);
|
||||
}
|
||||
|
||||
// Get the specified input port. Note that the special '-1' port_id can be
|
||||
// used to access the controlling nodes (i.e. the nodes connected to node_name
|
||||
// through an incoming control dependency).
|
||||
InputPort GetInputPort(const string& node_name, int port_id) const;
|
||||
InputPort GetInputPort(absl::string_view node_name, int port_id) const {
|
||||
return InputPort(GetNode(node_name), port_id);
|
||||
}
|
||||
|
||||
// Get the specified output port. Note that the special '-1' port_id can be
|
||||
// used to access the controlled nodes (i.e. the nodes connected to node_name
|
||||
// through an outgoing control dependency).
|
||||
OutputPort GetOutputPort(const string& node_name, int port_id) const;
|
||||
OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
|
||||
return OutputPort(GetNode(node_name), port_id);
|
||||
}
|
||||
|
||||
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
|
||||
// of an output (resp. input) port.
|
||||
const std::unordered_set<InputPort, HashPort>& GetFanout(
|
||||
const OutputPort& port) const;
|
||||
std::unordered_set<OutputPort, HashPort> GetFanin(
|
||||
const InputPort& port) const;
|
||||
const absl::flat_hash_set<InputPort>& GetFanout(
|
||||
const OutputPort& port) const {
|
||||
return gtl::FindWithDefault(fanouts_, port, empty_set_);
|
||||
}
|
||||
|
||||
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
|
||||
if (port.port_id >= 0) return {GetRegularFanin(port)};
|
||||
|
||||
// Collect fanin for the control input.
|
||||
absl::flat_hash_set<OutputPort> result;
|
||||
for (int i = port.node->input_size() - 1; i >= 0; --i) {
|
||||
TensorId tensor_id = ParseTensorName(port.node->input(i));
|
||||
if (tensor_id.index() >= 0) break; // we reached regular inputs
|
||||
|
||||
auto it = nodes_.find(tensor_id.node());
|
||||
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Special case: regular (i.e. non-control) input ports can only have one
|
||||
// fanin.
|
||||
const OutputPort GetRegularFanin(const InputPort& port) const;
|
||||
const OutputPort GetRegularFanin(const InputPort& port) const {
|
||||
DCHECK_GE(port.port_id, 0);
|
||||
if (port.port_id < 0) return OutputPort();
|
||||
|
||||
// Get all the input (resp. output) ports in the immediate fanout (resp fanin)
|
||||
// of a node. Include the controlling nodes iff include_controlling_nodes is
|
||||
// true.
|
||||
std::unordered_set<InputPort, HashPort> GetFanouts(
|
||||
const NodeDef& node, bool include_controlled_nodes) const;
|
||||
std::unordered_set<OutputPort, HashPort> GetFanins(
|
||||
const NodeDef& node, bool include_controlling_nodes) const;
|
||||
TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
|
||||
return GetOutputPort(tensor_id.node(), tensor_id.index());
|
||||
}
|
||||
|
||||
// Get all the input (resp. output) ports in the immediate fanout (resp
|
||||
// fanin) of a node. Include the controlling nodes iff
|
||||
// include_controlling_nodes is true.
|
||||
absl::flat_hash_set<InputPort> GetFanouts(
|
||||
const NodeDef& node, bool include_controlled_nodes) const {
|
||||
absl::flat_hash_set<InputPort> result;
|
||||
|
||||
OutputPort port;
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlled_nodes ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
auto it = fanouts_.find(port);
|
||||
if (it != fanouts_.end()) {
|
||||
result.insert(it->second.begin(), it->second.end());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<OutputPort> GetFanins(
|
||||
const NodeDef& node, bool include_controlling_nodes) const {
|
||||
absl::flat_hash_set<OutputPort> result;
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node.input(i));
|
||||
if (tensor_id.index() < 0 && !include_controlling_nodes) break;
|
||||
|
||||
auto it = nodes_.find(tensor_id.node());
|
||||
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Get the number of ports in the immediate fanin of a node. Count the
|
||||
// controlling nodes iff include_controlling_nodes is true.
|
||||
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const;
|
||||
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
|
||||
int count = 0;
|
||||
for (const string& input : node.input()) {
|
||||
if (!include_controlling_nodes && IsControlInput(input)) {
|
||||
break;
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
// Get all the edge in the immediate fanout (resp fanin) of a node. Include
|
||||
// the control edges iff include_controlling_edges is true.
|
||||
std::unordered_set<Edge, HashEdge> GetFanoutEdges(
|
||||
const NodeDef& node, bool include_controlled_edges) const;
|
||||
std::unordered_set<Edge, HashEdge> GetFaninEdges(
|
||||
const NodeDef& node, bool include_controlling_edges) const;
|
||||
// Get the number of ports in the immediate fanout of a node. Count the
|
||||
// controlling nodes iff include_controlling_nodes is true.
|
||||
int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
|
||||
int count = 0;
|
||||
|
||||
OutputPort port;
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlling_nodes ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
auto it = fanouts_.find(port);
|
||||
if (it != fanouts_.end()) count += it->second.size();
|
||||
}
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
// Get all the edges in the immediate fanout (resp fanin) of a node.
|
||||
// Include the control edges iff include_controlling_edges is true.
|
||||
absl::flat_hash_set<Edge> GetFanoutEdges(
|
||||
const NodeDef& node, bool include_controlled_edges) const {
|
||||
absl::flat_hash_set<Edge> result;
|
||||
|
||||
OutputPort port;
|
||||
port.node = const_cast<NodeDefT*>(&node);
|
||||
const int first_port_id = include_controlled_edges ? -1 : 0;
|
||||
const int last_port_id =
|
||||
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
|
||||
|
||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||
port.port_id = i;
|
||||
auto it = fanouts_.find(port);
|
||||
if (it != fanouts_.end()) {
|
||||
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
|
||||
result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
|
||||
/*dst*/ *itr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<Edge> GetFaninEdges(
|
||||
const NodeDef& node, bool include_controlling_edges) const {
|
||||
absl::flat_hash_set<Edge> result;
|
||||
for (int i = 0; i < node.input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node.input(i));
|
||||
if (tensor_id.index() < 0 && !include_controlling_edges) break;
|
||||
|
||||
auto it = nodes_.find(tensor_id.node());
|
||||
if (it != nodes_.end()) {
|
||||
result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
|
||||
/*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Add a new `node` to the graph.
|
||||
void AddUniqueNodeOrDie(NodeDef* node);
|
||||
// Add fanout to every `node` input.
|
||||
void AddFanouts(NodeDef* node);
|
||||
std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
|
||||
GraphDef* MutableGraph() { return graph_; }
|
||||
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
|
||||
|
||||
using FanoutsMapType =
|
||||
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
|
||||
HashPort>;
|
||||
FanoutsMapType* MutableFanouts() { return &fanouts_; }
|
||||
void AddUniqueNodeOrDie(NodeDefT* node) {
|
||||
auto result = nodes_.emplace(node->name(), node);
|
||||
// TODO(ezhulenev): Replace CHECK with factory method returning
|
||||
// absl::StatusOr (when available).
|
||||
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
||||
}
|
||||
|
||||
void AddFanouts(NodeDefT* node) {
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node->input(i));
|
||||
OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
|
||||
|
||||
if (output.port_id < 0) {
|
||||
fanouts_[output].emplace(node, -1);
|
||||
} else {
|
||||
num_regular_outputs_[output.node] =
|
||||
std::max(num_regular_outputs_[output.node], output.port_id);
|
||||
fanouts_[output].emplace(node, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Access to the mutable internal state for MutableGraphView.
|
||||
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
|
||||
return &nodes_;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
|
||||
mutable_fanouts() {
|
||||
return &fanouts_;
|
||||
}
|
||||
|
||||
private:
|
||||
GraphDef* graph_;
|
||||
std::unordered_map<string, NodeDef*> nodes_;
|
||||
std::unordered_set<InputPort, HashPort> empty_set_;
|
||||
FanoutsMapType fanouts_;
|
||||
std::unordered_map<const NodeDef*, int> num_regular_outputs_;
|
||||
GraphDefT* graph_; // must outlive the graph view
|
||||
absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
|
||||
absl::flat_hash_set<InputPort> empty_set_;
|
||||
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
|
||||
std::unordered_map<NodeDefT*, int> num_regular_outputs_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Immutable GraphView that keeps the constness of the GraphDef. If you need to
|
||||
// mutate the graph or the nodes via the graph view lookup functions, see
|
||||
// MutableGraphView.
|
||||
class GraphView
|
||||
: public internal::GraphViewInternal<const GraphDef, const NodeDef> {
|
||||
public:
|
||||
explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
|
||||
for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
|
||||
for (const NodeDef& node : graph->node()) AddFanouts(&node);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/ops/parsing_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) {
|
||||
|
||||
const NodeDef* add_node = graph.GetNode("AddN");
|
||||
EXPECT_NE(nullptr, add_node);
|
||||
string fanouts;
|
||||
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
|
||||
strings::StrAppend(&fanouts,
|
||||
strings::StrCat(fo.node->name(), ":", fo.port_id, " "));
|
||||
}
|
||||
EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts);
|
||||
|
||||
string fanins;
|
||||
for (const auto& fi : graph.GetFanins(*add_node, false)) {
|
||||
strings::StrAppend(&fanins,
|
||||
strings::StrCat(fi.node->name(), ":", fi.port_id, " "));
|
||||
absl::flat_hash_set<string> fanouts;
|
||||
absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
|
||||
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
|
||||
fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
|
||||
}
|
||||
EXPECT_EQ("Square_1:0 Square:0 ", fanins);
|
||||
EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
|
||||
EXPECT_EQ(fanouts, expected_fanouts);
|
||||
|
||||
absl::flat_hash_set<string> fanins;
|
||||
absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
|
||||
for (const auto& fi : graph.GetFanins(*add_node, false)) {
|
||||
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
|
||||
}
|
||||
EXPECT_EQ(graph.NumFanins(*add_node, false), 2);
|
||||
EXPECT_EQ(fanins, expected_fanins);
|
||||
}
|
||||
|
||||
TEST_F(GraphViewTest, ControlDependencies) {
|
||||
|
@ -19,8 +19,26 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
const absl::flat_hash_set<MutableGraphView::InputPort>&
|
||||
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
|
||||
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
|
||||
port.port_id));
|
||||
}
|
||||
|
||||
absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
|
||||
const GraphView::InputPort& port) const {
|
||||
return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
|
||||
port.port_id));
|
||||
}
|
||||
|
||||
const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
|
||||
const GraphView::InputPort& port) const {
|
||||
return GetRegularFanin(MutableGraphView::InputPort(
|
||||
const_cast<NodeDef*>(port.node), port.port_id));
|
||||
}
|
||||
|
||||
NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
||||
auto* node_in_graph = GetGraph()->add_node();
|
||||
auto* node_in_graph = graph()->add_node();
|
||||
*node_in_graph = std::move(node);
|
||||
|
||||
AddUniqueNodeOrDie(node_in_graph);
|
||||
@ -31,7 +49,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
||||
|
||||
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
||||
const int output_port_id) {
|
||||
auto* node_in_graph = GetGraph()->add_node();
|
||||
auto* node_in_graph = graph()->add_node();
|
||||
*node_in_graph = std::move(node);
|
||||
|
||||
AddUniqueNodeOrDie(node_in_graph);
|
||||
@ -46,8 +64,7 @@ NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
||||
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
||||
const NodeDef& new_input,
|
||||
const int output_port_id) {
|
||||
GraphView::OutputPort output_port =
|
||||
GetOutputPort(old_input.name(), output_port_id);
|
||||
OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
|
||||
auto fanout = GetFanout(output_port);
|
||||
for (auto& input_port : fanout) {
|
||||
input_port.node->set_input(input_port.port_id, new_input.name());
|
||||
@ -57,17 +74,17 @@ void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
||||
|
||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
RemoveFanouts(MutableNodes()->at(node_name_to_delete));
|
||||
RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
MutableNodes()->erase(node_name_to_delete);
|
||||
EraseNodesFromGraph(nodes_to_delete, GetGraph());
|
||||
mutable_nodes()->erase(node_name_to_delete);
|
||||
EraseNodesFromGraph(nodes_to_delete, graph());
|
||||
}
|
||||
|
||||
void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
OutputPort fanin;
|
||||
string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
|
||||
fanin.node = (*MutableNodes())[fanin_name];
|
||||
fanin.node = (*mutable_nodes())[fanin_name];
|
||||
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
@ -76,7 +93,7 @@ void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
||||
else
|
||||
input.port_id = i;
|
||||
|
||||
(*MutableFanouts())[fanin].erase(input);
|
||||
(*mutable_fanouts())[fanin].erase(input);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,11 +24,25 @@ namespace grappler {
|
||||
// A utility class to simplify the traversal of a GraphDef that, unlike
|
||||
// GraphView, supports updating the graph. Note that you should not modify the
|
||||
// graph separately, because the view will get out of sync.
|
||||
class MutableGraphView : public GraphView {
|
||||
public:
|
||||
using GraphView::GraphView;
|
||||
|
||||
GraphDef* GetGraph() { return MutableGraph(); }
|
||||
class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
public:
|
||||
explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
|
||||
for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
|
||||
for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node);
|
||||
}
|
||||
|
||||
// Lookup fanouts/fanins using immutable ports.
|
||||
using GraphViewInternal::GetFanout;
|
||||
const absl::flat_hash_set<InputPort>& GetFanout(
|
||||
const GraphView::OutputPort& port) const;
|
||||
|
||||
using GraphViewInternal::GetFanin;
|
||||
absl::flat_hash_set<OutputPort> GetFanin(
|
||||
const GraphView::InputPort& port) const;
|
||||
|
||||
using GraphViewInternal::GetRegularFanin;
|
||||
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
|
||||
|
||||
// Adds a new node to graph and updates the view.
|
||||
NodeDef* AddNode(NodeDef&& node);
|
||||
|
@ -26,7 +26,8 @@ namespace {
|
||||
bool FindChildWithName(const MutableGraphView& graph,
|
||||
const string& output_port_name,
|
||||
const string& input_name) {
|
||||
GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
|
||||
MutableGraphView::OutputPort output_port =
|
||||
graph.GetOutputPort(output_port_name, 0);
|
||||
auto fanout = graph.GetFanout(output_port);
|
||||
for (auto& input_port : fanout) {
|
||||
if (input_port.node->name() == input_name) return true;
|
||||
@ -59,10 +60,10 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
|
||||
GraphDef new_graph = item.graph;
|
||||
MutableGraphView graph(&new_graph);
|
||||
|
||||
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
EXPECT_EQ("AddN", input.node->name());
|
||||
EXPECT_EQ(0, input.port_id);
|
||||
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||
EXPECT_EQ("Square", fanin.node->name());
|
||||
EXPECT_EQ(0, fanin.port_id);
|
||||
|
||||
@ -89,7 +90,7 @@ TEST(MutableGraphViewTest, InsertNodes) {
|
||||
GraphDef new_graph = item.graph;
|
||||
MutableGraphView graph(&new_graph);
|
||||
|
||||
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||
|
||||
NodeDef new_node = *input.node;
|
||||
new_node.set_name("new_node");
|
||||
|
@ -145,8 +145,8 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/utils:functions",
|
||||
@ -422,8 +422,8 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/clusters:cluster",
|
||||
@ -625,12 +625,13 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
"//tensorflow/core/grappler/utils:frame",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
@ -663,8 +664,8 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
|
@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
|
||||
const FunctionDef& fused_function,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef fused_node;
|
||||
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(),
|
||||
&fused_node);
|
||||
|
||||
fused_node.set_op("FilterDataset");
|
||||
|
@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper(
|
||||
MutableGraphView* graph) {
|
||||
NodeDef node;
|
||||
node.set_op(kConstOpName);
|
||||
SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node);
|
||||
SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
|
||||
|
||||
(*node.mutable_attr())["dtype"].set_type(dtype);
|
||||
std::unique_ptr<tensorflow::TensorProto> tensor =
|
||||
@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper(
|
||||
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
|
||||
NodeDef node;
|
||||
node.set_op("Placeholder");
|
||||
SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
|
||||
SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
|
||||
(*node.mutable_attr())["dtype"].set_type(dtype);
|
||||
TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
|
||||
shape->set_unknown_rank(false);
|
||||
@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
|
||||
if (!name.empty()) {
|
||||
node.set_name(string(name));
|
||||
} else {
|
||||
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
|
||||
SetUniqueGraphNodeName(op, graph->graph(), &node);
|
||||
}
|
||||
node.set_op(string(op));
|
||||
for (const string& input : inputs) {
|
||||
@ -228,7 +228,7 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
|
||||
|
||||
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
|
||||
if (node.input_size() == 0) return nullptr;
|
||||
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
|
||||
MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
|
||||
return graph.GetRegularFanin(input_port).node;
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph()));
|
||||
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
|
||||
}
|
||||
|
||||
@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
|
||||
EXPECT_TRUE(
|
||||
ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
|
||||
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
|
||||
}
|
||||
|
||||
@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph()));
|
||||
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
|
||||
}
|
||||
|
||||
@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph()));
|
||||
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
|
||||
}
|
||||
|
||||
@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph()));
|
||||
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
|
||||
}
|
||||
|
||||
@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
|
||||
EXPECT_TRUE(
|
||||
ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph()));
|
||||
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
|
||||
}
|
||||
|
||||
@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) {
|
||||
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));
|
||||
|
||||
graph.DeleteNodes({"A"});
|
||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
|
||||
@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
|
||||
TEST(GraphUtilsTest, ContainsNodeWithOp) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
||||
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));
|
||||
|
||||
graph.DeleteNodes({"A"});
|
||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, FindGraphNodeWithName) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
||||
EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||
|
||||
graph.DeleteNodes({"A"});
|
||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
||||
@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
||||
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);
|
||||
|
||||
graph.DeleteNodes({"B"});
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
|
||||
}
|
||||
|
||||
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
|
||||
GraphDef graph_def;
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
|
||||
|
||||
AddNode("A", "OpA", {}, {}, &graph);
|
||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
||||
std::vector<int> result_indices =
|
||||
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
|
||||
FindAllGraphNodesWithOp("OpA", *graph.graph());
|
||||
EXPECT_EQ(result_indices.size(), 2);
|
||||
EXPECT_EQ(result_indices.at(0), 0);
|
||||
EXPECT_EQ(result_indices.at(1), 2);
|
||||
|
||||
graph.DeleteNodes({"A2"});
|
||||
std::vector<int> result_indices_new =
|
||||
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
|
||||
FindAllGraphNodesWithOp("OpA", *graph.graph());
|
||||
EXPECT_EQ(result_indices_new.size(), 1);
|
||||
EXPECT_EQ(result_indices_new.at(0), 0);
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
|
||||
const FunctionDef& stateless_function,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef stateless_map;
|
||||
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(),
|
||||
&stateless_map);
|
||||
|
||||
stateless_map.set_op("MapDataset");
|
||||
@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef random_dataset;
|
||||
random_dataset.set_op("RandomDataset");
|
||||
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(),
|
||||
&random_dataset);
|
||||
|
||||
const auto* seed = graph_utils::AddScalarConstNode<int64>(
|
||||
@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
|
||||
NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
|
||||
NodeDef batch_dataset;
|
||||
batch_dataset.set_op("BatchDatasetV2");
|
||||
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(),
|
||||
&batch_dataset);
|
||||
const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
|
||||
const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
|
||||
@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
|
||||
NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef zip_node;
|
||||
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(),
|
||||
&zip_node);
|
||||
|
||||
zip_node.set_op("ZipDataset");
|
||||
|
@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
|
||||
NodeDef new_node;
|
||||
new_node.set_op(kInsertOpName);
|
||||
graph_utils::SetUniqueGraphNodeName(
|
||||
strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(),
|
||||
&new_node);
|
||||
strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node);
|
||||
// Set the input of LatencyDataset node as `node`
|
||||
new_node.add_input(node.name());
|
||||
|
||||
@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// node corresponds to a `Dataset` op.
|
||||
continue;
|
||||
}
|
||||
GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0);
|
||||
MutableGraphView::OutputPort output_port =
|
||||
graph.GetOutputPort(node.name(), 0);
|
||||
auto fanout = graph.GetFanout(output_port);
|
||||
if (fanout.size() > 1) {
|
||||
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
|
||||
|
@ -29,7 +29,7 @@ namespace {
|
||||
|
||||
NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) {
|
||||
NodeDef numa_aware_node = node;
|
||||
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(),
|
||||
&numa_aware_node);
|
||||
numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
|
||||
return numa_aware_node;
|
||||
|
@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef new_node;
|
||||
new_node.set_op(kFusedOpName);
|
||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
|
||||
&new_node);
|
||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node);
|
||||
|
||||
// Set the `input` input argument.
|
||||
new_node.add_input(map_node.input(0));
|
||||
|
@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) {
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
|
||||
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -37,8 +37,7 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
|
||||
const FunctionDef& fused_function,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef fused_node;
|
||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
|
||||
&fused_node);
|
||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
|
||||
fused_node.set_op("MapDataset");
|
||||
fused_node.add_input(map_node.input(0));
|
||||
|
||||
@ -72,8 +71,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node,
|
||||
const NodeDef& filter_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef filter_by_component;
|
||||
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent",
|
||||
graph->GetGraph(), &filter_by_component);
|
||||
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(),
|
||||
&filter_by_component);
|
||||
filter_by_component.set_op("FilterByLastComponentDataset");
|
||||
filter_by_component.add_input(fused_map_node.name());
|
||||
|
||||
|
@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
|
||||
const FunctionDef& fused_function,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef fused_node;
|
||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
|
||||
&fused_node);
|
||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
|
||||
fused_node.set_op("MapDataset");
|
||||
fused_node.add_input(parent_map_node.input(0));
|
||||
|
||||
|
@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function,
|
||||
|
||||
NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
|
||||
NodeDef parallel_map = map_node;
|
||||
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(),
|
||||
¶llel_map);
|
||||
parallel_map.set_op("ParallelMapDataset");
|
||||
// TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
|
||||
|
@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef batch_node;
|
||||
batch_node.set_op(old_batch_node.op());
|
||||
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
|
||||
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(),
|
||||
&batch_node);
|
||||
|
||||
// Set the `input_dataset` input argument
|
||||
@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
|
||||
MutableGraphView* graph) {
|
||||
NodeDef map_node;
|
||||
map_node.set_op(old_map_node.op());
|
||||
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
|
||||
&map_node);
|
||||
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node);
|
||||
|
||||
// Set the `input_dataset` input argument
|
||||
map_node.add_input(new_batch_node.name());
|
||||
|
@ -30,7 +30,7 @@ namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
|
||||
bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
|
||||
if (take_node.op() != "TakeDataset") return false;
|
||||
|
||||
const auto& count_node = *graph.GetNode(take_node.input(1));
|
||||
@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) {
|
||||
return node.attr().at("value").tensor().int64_val(0) == value;
|
||||
}
|
||||
|
||||
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
|
||||
bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) {
|
||||
if (skip_node.op() != "SkipDataset") return false;
|
||||
// We are looking only for skip(0) nodes.
|
||||
return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
|
||||
}
|
||||
|
||||
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
|
||||
bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) {
|
||||
if (repeat_node.op() != "RepeatDataset") return false;
|
||||
// We are looking only for repeat(1) nodes.
|
||||
return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
|
||||
}
|
||||
|
||||
bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) {
|
||||
bool IsPrefetchZero(const NodeDef& prefetch_node,
|
||||
const MutableGraphView& graph) {
|
||||
if (prefetch_node.op() != "PrefetchDataset") return false;
|
||||
// We are looking only for prefetch(0) nodes.
|
||||
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
|
||||
}
|
||||
|
||||
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
|
||||
bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
|
||||
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
|
||||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
|
||||
}
|
||||
|
@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
|
||||
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -31,8 +31,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/core/framework/versions.pb.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/functions.h"
|
||||
@ -219,8 +219,7 @@ class FunctionOptimizerContext {
|
||||
: grappler_item_id_(item.id),
|
||||
graph_version_(item.graph.versions().producer()),
|
||||
function_library_(OpRegistry::Global(), item.graph.library()),
|
||||
// GraphView doesn't not modify the graph or the nodes.
|
||||
graph_view_(const_cast<GraphDef*>(&item.graph)) {
|
||||
graph_view_(&item.graph) {
|
||||
InitializeTrulyConstNodes(item);
|
||||
InitializeInlinedFunctions(opt_level, item);
|
||||
InitializeFetchNodes(item);
|
||||
@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// Function specialization might change the number of function outputs, so we
|
||||
// have to process the final optimized graph and update all the node mapping.
|
||||
if (ctx.RequiresOutputMapping()) {
|
||||
GraphView optimized_graph_view(optimized_graph);
|
||||
MutableGraphView optimized_graph_view(optimized_graph);
|
||||
for (const auto& output_mapping : ctx.output_mappings()) {
|
||||
const auto& node_name = output_mapping.first;
|
||||
const auto& mappings = output_mapping.second;
|
||||
@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
int to = mapping.second;
|
||||
|
||||
// Get the output port corresponding to the old output position.
|
||||
GraphView::OutputPort from_port =
|
||||
MutableGraphView::OutputPort from_port =
|
||||
optimized_graph_view.GetOutputPort(node_name, from);
|
||||
|
||||
// Update all input ports that read from old output port.
|
||||
for (GraphView::InputPort to_port :
|
||||
for (MutableGraphView::InputPort to_port :
|
||||
optimized_graph_view.GetFanout(from_port)) {
|
||||
*to_port.node->mutable_input(to_port.port_id) =
|
||||
strings::StrCat(node_name, ":", to);
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
@ -29,8 +30,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
|
||||
@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
|
||||
const NodeMap& node_map, DeviceBase* cpu_device,
|
||||
ResourceMgr* resource_mgr, bool* has_dead_fanout,
|
||||
int* dead_fanout) {
|
||||
Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
const NodeDef& switch_node, const NodeMap& node_map,
|
||||
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
|
||||
bool* has_dead_fanout, int* dead_fanout) {
|
||||
*has_dead_fanout = false;
|
||||
GraphView::InputPort switch_loopcond_port(&switch_node, 1);
|
||||
NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node;
|
||||
const NodeDef* switch_predicate =
|
||||
view.GetRegularFanin(switch_loopcond_port).node;
|
||||
|
||||
// CASE 1: Control is a constant.
|
||||
if (IsConstant(*switch_predicate)) {
|
||||
@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
|
||||
}
|
||||
|
||||
GraphView::InputPort switch_input_port(&switch_node, 0);
|
||||
NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
|
||||
const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
|
||||
|
||||
// CASE 2: Zero-iteration while loop.
|
||||
// We check if its a while loop such that the condition is a simple binary
|
||||
@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
|
||||
// TODO(bsteiner): also rewrite switches as identity. For now we just record
|
||||
// them
|
||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
||||
identity_switches;
|
||||
absl::flat_hash_set<GraphView::OutputPort> identity_switches;
|
||||
|
||||
GraphView view(optimized_graph);
|
||||
MutableGraphView view(optimized_graph);
|
||||
for (const NodeDef& node : optimized_graph->node()) {
|
||||
if (!IsSwitch(node)) {
|
||||
continue;
|
||||
@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
if (!has_dead_fanout) {
|
||||
continue;
|
||||
}
|
||||
GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
|
||||
GraphView::OutputPort dead(&node, dead_fanout);
|
||||
identity_switches.insert(dead);
|
||||
|
||||
SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
|
||||
for (const GraphView::InputPort& port : view.GetFanout(dead)) {
|
||||
SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
|
||||
zombie_inputs;
|
||||
for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
|
||||
if (dead_nodes.find(port.node) == dead_nodes.end()) {
|
||||
zombie_inputs.PushBack(port);
|
||||
}
|
||||
@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
dead_merge_inputs;
|
||||
bool found_node_to_preserve = false;
|
||||
while (!found_node_to_preserve && !zombie_inputs.Empty()) {
|
||||
GraphView::InputPort dead = zombie_inputs.PopBack();
|
||||
MutableGraphView::InputPort dead = zombie_inputs.PopBack();
|
||||
if (nodes_to_preserve.find(dead.node->name()) !=
|
||||
nodes_to_preserve.end()) {
|
||||
found_node_to_preserve = true;
|
||||
@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
found_node_to_preserve = true;
|
||||
break;
|
||||
}
|
||||
GraphView::OutputPort value_index(dead.node, 1);
|
||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
|
||||
index_fanout = view.GetFanout(value_index);
|
||||
MutableGraphView::OutputPort value_index(dead.node, 1);
|
||||
const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
|
||||
view.GetFanout(value_index);
|
||||
if (!index_fanout.empty()) {
|
||||
// The 2nd output (that indicates which input is propagated) is
|
||||
// connected. This never happens in practice, so we'll just skip this
|
||||
@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
}
|
||||
if (fully_dead) {
|
||||
local_dead_nodes.insert(dead.node);
|
||||
for (const GraphView::InputPort& port :
|
||||
for (const MutableGraphView::InputPort& port :
|
||||
view.GetFanouts(*dead.node, true)) {
|
||||
zombie_inputs.PushBack(port);
|
||||
}
|
||||
@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
break;
|
||||
} else {
|
||||
if (local_dead_nodes.insert(dead.node).second) {
|
||||
for (const GraphView::InputPort& dead_fanout :
|
||||
for (const MutableGraphView::InputPort& dead_fanout :
|
||||
view.GetFanouts(*dead.node, true)) {
|
||||
zombie_inputs.PushBack(dead_fanout);
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
|
||||
@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
|
||||
|
||||
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
// Look for AddN nodes (and equivalent) and record input names.
|
||||
GraphView view(&item->graph);
|
||||
MutableGraphView view(&item->graph);
|
||||
|
||||
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
|
||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||
@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
const string& input = node->input(i);
|
||||
const string node_name = NodeName(input);
|
||||
NodeDef* node = view.GetNode(node_name);
|
||||
const NodeDef* node = view.GetNode(node_name);
|
||||
input_topo_index.push_back(topo_order.at(node));
|
||||
}
|
||||
int min_input_topo_index = INT_MAX;
|
||||
@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
||||
static bool IsSwappable(const MutableGraphView& graph,
|
||||
MutableGraphView::OutputPort output) {
|
||||
const NodeDef& node = *output.node;
|
||||
// There is no point in swapping out persistent tensors, since the tensor will
|
||||
// continue to use memory.
|
||||
@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
||||
// If placed on the same device, these nodes are just forwarding references
|
||||
// to their input. Therefore they are swappable iff their fanin is swappable
|
||||
// or it resides on a different device.
|
||||
GraphView::InputPort input;
|
||||
MutableGraphView::InputPort input;
|
||||
input.node = output.node;
|
||||
input.port_id = 0;
|
||||
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||
if (fanin.node->device() == node.device()) {
|
||||
return IsSwappable(graph, fanin);
|
||||
}
|
||||
@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
||||
}
|
||||
|
||||
static NodeDef* FindSwapOutTrigger(
|
||||
const NodeDef* node, int input_id, const GraphView& view,
|
||||
const NodeDef* node, int input_id, const MutableGraphView& view,
|
||||
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
|
||||
execution_times) {
|
||||
// Find the output port that generated the tensor to swap.
|
||||
GraphView::InputPort swap;
|
||||
MutableGraphView::InputPort swap;
|
||||
swap.node = const_cast<NodeDef*>(node);
|
||||
swap.port_id = input_id;
|
||||
GraphView::OutputPort generator = view.GetRegularFanin(swap);
|
||||
MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
|
||||
if (!generator.node) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
|
||||
const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
|
||||
view.GetFanout(generator);
|
||||
NodeDef* trigger = nullptr;
|
||||
Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
|
||||
@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger(
|
||||
return trigger;
|
||||
}
|
||||
|
||||
static bool IsSwappable(GraphView::InputPort input) {
|
||||
static bool IsSwappable(MutableGraphView::InputPort input) {
|
||||
const NodeDef& node = *input.node;
|
||||
|
||||
const OpDef* op_def;
|
||||
@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) {
|
||||
}
|
||||
|
||||
struct MemInfo {
|
||||
GraphView::OutputPort port;
|
||||
MutableGraphView::OutputPort port;
|
||||
int64 memory_used;
|
||||
std::vector<GraphView::InputPort> uses_left;
|
||||
std::vector<MutableGraphView::InputPort> uses_left;
|
||||
double fitness;
|
||||
|
||||
bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
|
||||
@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates(
|
||||
|
||||
std::vector<MemInfo> mem_state;
|
||||
|
||||
GraphView graph(&item->graph);
|
||||
MutableGraphView graph(&item->graph);
|
||||
for (const auto& live_tensor : mem_usage.live_tensors) {
|
||||
if (live_tensor.memory_used <= 1024) {
|
||||
// Don't bother with small tensors.
|
||||
@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates(
|
||||
if (skip_list->find(live_tensor.node) != skip_list->end()) {
|
||||
continue;
|
||||
}
|
||||
GraphView::OutputPort port =
|
||||
MutableGraphView::OutputPort port =
|
||||
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
|
||||
if (!IsSwappable(graph, port)) {
|
||||
continue;
|
||||
@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates(
|
||||
Costs::Duration allocation_time = live_tensor.allocation_time;
|
||||
Costs::Duration earliest_use(Costs::Duration::infinity());
|
||||
bool valid = true;
|
||||
for (GraphView::InputPort input : graph.GetFanout(port)) {
|
||||
for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
|
||||
// Get execution time.
|
||||
auto it = op_completion_times.find(input.node->name());
|
||||
if (it == op_completion_times.end()) {
|
||||
@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates(
|
||||
// the values do not fit into any integral type.
|
||||
mem_info.fitness =
|
||||
MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
|
||||
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
|
||||
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
|
||||
MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
|
||||
mem_info.fitness = -mem_info.fitness;
|
||||
mem_state.push_back(mem_info);
|
||||
@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates(
|
||||
std::sort(mem_state.begin(), mem_state.end());
|
||||
|
||||
for (const MemInfo& mem_info : mem_state) {
|
||||
for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
|
||||
for (const MutableGraphView::InputPort fanout_to_swap :
|
||||
mem_info.uses_left) {
|
||||
VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
|
||||
<< fanout_to_swap.port_id << " of tensor "
|
||||
<< mem_info.port.node->name() << ":" << mem_info.port.port_id
|
||||
@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
|
||||
for (const auto& node : item->graph.node()) {
|
||||
name_map[node.name()] = &node;
|
||||
}
|
||||
GraphView view(&item->graph);
|
||||
MutableGraphView view(&item->graph);
|
||||
|
||||
bool updated_graph = false;
|
||||
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||
@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
GraphProperties properties(item);
|
||||
bool inferred_properties = false;
|
||||
GraphView graph(optimized_graph);
|
||||
MutableGraphView graph(optimized_graph);
|
||||
|
||||
// The product of all the dimensions in a tensor shape can be expressed more
|
||||
// simply as the size of the tensor.
|
||||
@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
if (!IsShape(node)) {
|
||||
continue;
|
||||
}
|
||||
for (GraphView::InputPort fanout :
|
||||
graph.GetFanout(GraphView::OutputPort(&node, 0))) {
|
||||
for (MutableGraphView::InputPort fanout :
|
||||
graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
|
||||
if (fanout.node->op() != "Prod") {
|
||||
continue;
|
||||
}
|
||||
@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// rewrite the whole expression directly as a Size operation.
|
||||
continue;
|
||||
}
|
||||
const GraphView::OutputPort reduce_indices =
|
||||
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
|
||||
const MutableGraphView::OutputPort reduce_indices =
|
||||
graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
|
||||
if (!inferred_properties) {
|
||||
// Infer properties lazily in case they are not needed.
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(false));
|
||||
@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// is possible whenever the symbolic dimensions in the numerator and
|
||||
// denominator cancel each other.
|
||||
if (node.op() == "Div") {
|
||||
const GraphView::OutputPort input1 =
|
||||
graph.GetRegularFanin(GraphView::InputPort(&node, 0));
|
||||
const GraphView::OutputPort input2 =
|
||||
graph.GetRegularFanin(GraphView::InputPort(&node, 1));
|
||||
const MutableGraphView::OutputPort input1 =
|
||||
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
|
||||
const MutableGraphView::OutputPort input2 =
|
||||
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
|
||||
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -101,6 +101,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_view",
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
@ -21,8 +21,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
void ReverseDfs(
|
||||
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||
namespace {
|
||||
|
||||
template <typename GraphViewType>
|
||||
void ReverseDfsInternal(
|
||||
const GraphViewType& graph_view, const std::vector<const NodeDef*>& from,
|
||||
const std::function<void(const NodeDef*)>& pre_order,
|
||||
const std::function<void(const NodeDef*)>& post_order,
|
||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||
@ -79,5 +82,25 @@ void ReverseDfs(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void ReverseDfs(
|
||||
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||
const std::function<void(const NodeDef*)>& pre_order,
|
||||
const std::function<void(const NodeDef*)>& post_order,
|
||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||
ReverseDfsInternal<GraphView>(graph_view, from, pre_order, post_order,
|
||||
on_back_edge);
|
||||
}
|
||||
|
||||
void ReverseDfs(
|
||||
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||
const std::function<void(const NodeDef*)>& pre_order,
|
||||
const std::function<void(const NodeDef*)>& post_order,
|
||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||
ReverseDfsInternal<MutableGraphView>(graph_view, from, pre_order, post_order,
|
||||
on_back_edge);
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
#include "tensorflow/core/grappler/graph_view.h"
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -34,6 +35,12 @@ void ReverseDfs(
|
||||
const std::function<void(const NodeDef*)>& post_order,
|
||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
|
||||
|
||||
void ReverseDfs(
|
||||
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||
const std::function<void(const NodeDef*)>& pre_order,
|
||||
const std::function<void(const NodeDef*)>& post_order,
|
||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -14,9 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/traversal.h"
|
||||
//#include "tensorflow/core/framework/node_def.pb.h"
|
||||
//#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
//#include "tensorflow/core/platform/protobuf.h"
|
||||
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) {
|
||||
found_back_edge = true;
|
||||
});
|
||||
|
||||
EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order);
|
||||
EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order);
|
||||
// Pre/Post order traversals are non deterministic because a node fanin is an
|
||||
// absl::flat_hash_set with non deterministic traversal order.
|
||||
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
|
||||
|
||||
std::set<ValidTraversal> valid_traversals = {
|
||||
// pre_order post_order
|
||||
{{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}},
|
||||
{{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}};
|
||||
|
||||
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
|
||||
EXPECT_FALSE(found_back_edge);
|
||||
}
|
||||
|
||||
@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) {
|
||||
back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
|
||||
});
|
||||
|
||||
EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order);
|
||||
EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order);
|
||||
// Pre/Post order traversals are non deterministic because a node fanin is an
|
||||
// absl::flat_hash_set with non deterministic traversal order.
|
||||
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
|
||||
|
||||
std::set<ValidTraversal> valid_traversals = {
|
||||
// pre_order post_order
|
||||
{{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}},
|
||||
{{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}},
|
||||
{{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}};
|
||||
|
||||
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
|
||||
EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user