[Grappler] Preserve constness of the GraphDef in GraphView.
1. Split GraphView into GraphView and MutableGraphView with separate {Input/Output}Port types with different node pointer constness. 2. Properly use GraphView and MutableGraphView in graph properties, and get rid of const_cast. 3. Remove const_cast in function optimizer. 4. Migrate GraphView to absl containers and hash. PiperOrigin-RevId: 219488040
This commit is contained in:
parent
92e604060a
commit
3eeaf9f1e1
@ -69,6 +69,9 @@ cc_library(
|
|||||||
":utils",
|
":utils",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/hash",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -82,6 +85,8 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ cc_library(
|
|||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"//tensorflow/core/grappler/utils:functions",
|
"//tensorflow/core/grappler/utils:functions",
|
||||||
"//tensorflow/core/grappler/utils:topological_sort",
|
"//tensorflow/core/grappler/utils:topological_sort",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core:core_cpu_base",
|
"//tensorflow/core:core_cpu_base",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -30,7 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/grappler/costs/utils.h"
|
#include "tensorflow/core/grappler/costs/utils.h"
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
@ -456,10 +456,10 @@ class SymbolicShapeRefiner {
|
|||||||
const GraphView& graph,
|
const GraphView& graph,
|
||||||
const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
|
const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
|
||||||
: graph_(graph),
|
: graph_(graph),
|
||||||
function_library_(OpRegistry::Global(), graph.GetGraph()->library()),
|
function_library_(OpRegistry::Global(), graph.graph()->library()),
|
||||||
fed_ports_(fed_ports) {
|
fed_ports_(fed_ports) {
|
||||||
graph_def_version_ = graph.GetGraph()->versions().producer();
|
graph_def_version_ = graph.graph()->versions().producer();
|
||||||
node_to_context_.reserve(graph.GetGraph()->node_size());
|
node_to_context_.reserve(graph.graph()->node_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
const GraphView& graph() const { return graph_; }
|
const GraphView& graph() const { return graph_; }
|
||||||
@ -512,7 +512,7 @@ class SymbolicShapeRefiner {
|
|||||||
// Placeholder with Const) don't affect one in
|
// Placeholder with Const) don't affect one in
|
||||||
// fun_to_grappler_function_item_.
|
// fun_to_grappler_function_item_.
|
||||||
GrapplerFunctionItem grappler_function_item = it->second;
|
GrapplerFunctionItem grappler_function_item = it->second;
|
||||||
GraphView gv(&grappler_function_item.graph);
|
MutableGraphView gv(&grappler_function_item.graph);
|
||||||
|
|
||||||
// Forward shapes from function input nodes to argument nodes.
|
// Forward shapes from function input nodes to argument nodes.
|
||||||
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
|
for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
|
||||||
@ -532,7 +532,7 @@ class SymbolicShapeRefiner {
|
|||||||
"Function inputs should not contain control nodes.");
|
"Function inputs should not contain control nodes.");
|
||||||
}
|
}
|
||||||
|
|
||||||
NodeDef* input_node = graph_.GetNode(node_name);
|
const NodeDef* input_node = graph_.GetNode(node_name);
|
||||||
if (input_node == nullptr) {
|
if (input_node == nullptr) {
|
||||||
return errors::FailedPrecondition(node_name,
|
return errors::FailedPrecondition(node_name,
|
||||||
" was not found in the graph.");
|
" was not found in the graph.");
|
||||||
@ -566,7 +566,7 @@ class SymbolicShapeRefiner {
|
|||||||
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
|
for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
|
||||||
const string& input = function_node->input(i);
|
const string& input = function_node->input(i);
|
||||||
const string& node_name = NodeName(input);
|
const string& node_name = NodeName(input);
|
||||||
NodeDef* input_node = graph_.GetNode(node_name);
|
const NodeDef* input_node = graph_.GetNode(node_name);
|
||||||
if (IsConstant(*input_node)) {
|
if (IsConstant(*input_node)) {
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
|
ReplaceInputWithConst(*input_node, i, &grappler_function_item));
|
||||||
@ -1441,8 +1441,8 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
ShapeHandle input = in->output(fanin.src.port_id);
|
ShapeHandle input = in->output(fanin.src.port_id);
|
||||||
CHECK_EQ(fanin.tgt.node, node);
|
CHECK_EQ(fanin.dst.node, node);
|
||||||
c->SetInput(fanin.tgt.port_id, input);
|
c->SetInput(fanin.dst.port_id, input);
|
||||||
if (!out_initialized) {
|
if (!out_initialized) {
|
||||||
out_initialized = true;
|
out_initialized = true;
|
||||||
out = input;
|
out = input;
|
||||||
@ -1673,7 +1673,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphView graph_view(const_cast<GraphDef*>(&item_.graph));
|
GraphView graph_view(&item_.graph);
|
||||||
|
|
||||||
// List the resources and the nodes using them. Also collect the Merge nodes,
|
// List the resources and the nodes using them. Also collect the Merge nodes,
|
||||||
// fed nodes, and primary inputs.
|
// fed nodes, and primary inputs.
|
||||||
@ -1725,10 +1725,10 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
|
|||||||
for (const auto& resource : resources) {
|
for (const auto& resource : resources) {
|
||||||
for (const NodeDef* src : resource.second.first) {
|
for (const NodeDef* src : resource.second.first) {
|
||||||
resource_handles[src] = resource.first;
|
resource_handles[src] = resource.first;
|
||||||
for (const NodeDef* tgt : resource.second.second) {
|
for (const NodeDef* dst : resource.second.second) {
|
||||||
// Add control edges from enqueue to dequeue nodes to ensure they are
|
// Add control edges from enqueue to dequeue nodes to ensure they are
|
||||||
// processed in their logical order.
|
// processed in their logical order.
|
||||||
extra_deps.emplace_back(src, tgt);
|
extra_deps.emplace_back(src, dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,217 +63,5 @@ int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
|
|||||||
return OpPortIdToArgId(node, op.input_arg(), port_id);
|
return OpPortIdToArgId(node, op.input_arg(), port_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
|
|
||||||
for (int i = 0; i < graph_->node_size(); i++) {
|
|
||||||
auto node = graph_->mutable_node(i);
|
|
||||||
AddUniqueNodeOrDie(node);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (NodeDef& node : *graph_->mutable_node()) {
|
|
||||||
AddFanouts(&node);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
|
|
||||||
auto result = nodes_.emplace(node->name(), node);
|
|
||||||
// Check that the graph doesn't contain multiple nodes with the same name.
|
|
||||||
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphView::AddFanouts(NodeDef* node) {
|
|
||||||
for (int i = 0; i < node->input_size(); ++i) {
|
|
||||||
OutputPort fanin;
|
|
||||||
const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
|
|
||||||
fanin.node = nodes_[fanin_name];
|
|
||||||
|
|
||||||
InputPort input;
|
|
||||||
input.node = node;
|
|
||||||
if (fanin.port_id < 0) {
|
|
||||||
input.port_id = -1;
|
|
||||||
} else {
|
|
||||||
input.port_id = i;
|
|
||||||
num_regular_outputs_[fanin.node] =
|
|
||||||
std::max(num_regular_outputs_[fanin.node], fanin.port_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
fanouts_[fanin].insert(input);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
NodeDef* GraphView::GetNode(const string& node_name) const {
|
|
||||||
auto it = nodes_.find(node_name);
|
|
||||||
if (it == nodes_.end()) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphView::InputPort GraphView::GetInputPort(const string& node_name,
|
|
||||||
int port_id) const {
|
|
||||||
InputPort result;
|
|
||||||
result.node = GetNode(node_name);
|
|
||||||
// TODO(bsteiner): verify that the node has at least port_id input ports
|
|
||||||
result.port_id = port_id;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
GraphView::OutputPort GraphView::GetOutputPort(const string& node_name,
|
|
||||||
int port_id) const {
|
|
||||||
OutputPort result;
|
|
||||||
result.node = GetNode(node_name);
|
|
||||||
// TODO(bsteiner): verify that the node has at least port_id output ports
|
|
||||||
result.port_id = port_id;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
|
|
||||||
GraphView::GetFanout(const GraphView::OutputPort& port) const {
|
|
||||||
auto it = fanouts_.find(port);
|
|
||||||
if (it == fanouts_.end()) {
|
|
||||||
return empty_set_;
|
|
||||||
}
|
|
||||||
return it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
|
||||||
GraphView::GetFanin(const GraphView::InputPort& port) const {
|
|
||||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort> result;
|
|
||||||
if (port.port_id >= 0) {
|
|
||||||
result.insert(GetRegularFanin(port));
|
|
||||||
} else {
|
|
||||||
for (int i = port.node->input_size() - 1; i >= 0; --i) {
|
|
||||||
OutputPort fanin;
|
|
||||||
string fanin_name = ParseNodeName(port.node->input(i), &fanin.port_id);
|
|
||||||
if (fanin.port_id < 0) {
|
|
||||||
auto it = nodes_.find(fanin_name);
|
|
||||||
if (it != nodes_.end()) {
|
|
||||||
fanin.node = it->second;
|
|
||||||
result.insert(fanin);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
const GraphView::OutputPort GraphView::GetRegularFanin(
|
|
||||||
const GraphView::InputPort& port) const {
|
|
||||||
CHECK_LE(0, port.port_id);
|
|
||||||
OutputPort fanin;
|
|
||||||
string fanin_name =
|
|
||||||
ParseNodeName(port.node->input(port.port_id), &fanin.port_id);
|
|
||||||
auto it = nodes_.find(fanin_name);
|
|
||||||
if (it == nodes_.end()) {
|
|
||||||
fanin.node = nullptr;
|
|
||||||
} else {
|
|
||||||
fanin.node = it->second;
|
|
||||||
}
|
|
||||||
return fanin;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_set<GraphView::InputPort, GraphView::HashPort>
|
|
||||||
GraphView::GetFanouts(const NodeDef& node,
|
|
||||||
bool include_controlled_nodes) const {
|
|
||||||
std::unordered_set<InputPort, HashPort> result;
|
|
||||||
OutputPort port;
|
|
||||||
port.node = const_cast<NodeDef*>(&node);
|
|
||||||
const int first_port_id = include_controlled_nodes ? -1 : 0;
|
|
||||||
auto it = num_regular_outputs_.find(&node);
|
|
||||||
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
|
|
||||||
|
|
||||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
|
||||||
port.port_id = i;
|
|
||||||
auto it = fanouts_.find(port);
|
|
||||||
if (it != fanouts_.end()) {
|
|
||||||
result.insert(it->second.begin(), it->second.end());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
|
||||||
GraphView::GetFanins(const NodeDef& node,
|
|
||||||
bool include_controlling_nodes) const {
|
|
||||||
std::unordered_set<OutputPort, HashPort> result;
|
|
||||||
for (int i = 0; i < node.input_size(); ++i) {
|
|
||||||
OutputPort fanin;
|
|
||||||
string fanin_name = ParseNodeName(node.input(i), &fanin.port_id);
|
|
||||||
if (fanin.port_id < 0) {
|
|
||||||
if (!include_controlling_nodes) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto it = nodes_.find(fanin_name);
|
|
||||||
if (it != nodes_.end()) {
|
|
||||||
fanin.node = it->second;
|
|
||||||
result.insert(fanin);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
int GraphView::NumFanins(const NodeDef& node,
|
|
||||||
bool include_controlling_nodes) const {
|
|
||||||
int count = 0;
|
|
||||||
for (const string& input : node.input()) {
|
|
||||||
if (!include_controlling_nodes && IsControlInput(input)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
count += 1;
|
|
||||||
}
|
|
||||||
return count;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
|
|
||||||
GraphView::GetFanoutEdges(const NodeDef& node,
|
|
||||||
bool include_controlled_edges) const {
|
|
||||||
std::unordered_set<Edge, HashEdge> result;
|
|
||||||
OutputPort port;
|
|
||||||
port.node = const_cast<NodeDef*>(&node);
|
|
||||||
const int first_port_id = include_controlled_edges ? -1 : 0;
|
|
||||||
auto it = num_regular_outputs_.find(&node);
|
|
||||||
const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
|
|
||||||
|
|
||||||
for (int i = first_port_id; i <= last_port_id; ++i) {
|
|
||||||
port.port_id = i;
|
|
||||||
auto it = fanouts_.find(port);
|
|
||||||
if (it != fanouts_.end()) {
|
|
||||||
Edge fanout;
|
|
||||||
fanout.src.node = const_cast<NodeDef*>(&node);
|
|
||||||
fanout.src.port_id = i;
|
|
||||||
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
|
|
||||||
fanout.tgt = *itr;
|
|
||||||
result.insert(fanout);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_set<GraphView::Edge, GraphView::HashEdge>
|
|
||||||
GraphView::GetFaninEdges(const NodeDef& node,
|
|
||||||
bool include_controlling_edges) const {
|
|
||||||
std::unordered_set<Edge, HashEdge> result;
|
|
||||||
for (int i = 0; i < node.input_size(); ++i) {
|
|
||||||
Edge fanin;
|
|
||||||
fanin.tgt.node = const_cast<NodeDef*>(&node);
|
|
||||||
fanin.tgt.port_id = i;
|
|
||||||
string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
|
|
||||||
if (fanin.src.port_id < 0) {
|
|
||||||
if (!include_controlling_edges) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto it = nodes_.find(fanin_name);
|
|
||||||
if (it != nodes_.end()) {
|
|
||||||
fanin.src.node = it->second;
|
|
||||||
result.insert(fanin);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -18,9 +18,16 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/hash/hash.h"
|
||||||
|
#include "absl/strings/string_view.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
|
#include "tensorflow/core/graph/tensor_id.h"
|
||||||
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -36,114 +43,290 @@ namespace grappler {
|
|||||||
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
||||||
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
|
||||||
|
|
||||||
// A utility class to simplify the traversal of a GraphDef.
|
namespace internal {
|
||||||
class GraphView {
|
|
||||||
|
// GraphViewInternal is a helper class to simplify graph traversal. It creates
|
||||||
|
// an immutable view of the nodes and edges represented by a GraphDef protocol
|
||||||
|
// buffer.
|
||||||
|
//
|
||||||
|
// There are two public classes implementing GraphViewInternal:
|
||||||
|
//
|
||||||
|
// - GraphView: constructed from the `const GraphDef` and doesn't allow
|
||||||
|
// to mutate underlying graph via input/output ports lookup functions (ports
|
||||||
|
// have const pointers to nodes).
|
||||||
|
//
|
||||||
|
// - MutableGraphView: constructed from the `GraphDef` and allows mutating
|
||||||
|
// the graph via input/output ports lookup functions (ports have non-const
|
||||||
|
// pointers to nodes), and also has a couple of additional functions to
|
||||||
|
// add/remove/replace nodes in the graph.
|
||||||
|
//
|
||||||
|
// --------------------------- !!! WARNING !!! ---------------------------------
|
||||||
|
// Removing nodes from the graph outside of MutableGraphView will
|
||||||
|
// lead to segfaults! Guaranteed by absl::string_view!
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
//
|
||||||
|
template <typename GraphDefT, typename NodeDefT>
|
||||||
|
class GraphViewInternal {
|
||||||
public:
|
public:
|
||||||
struct Port {
|
struct Port {
|
||||||
Port() = default;
|
Port() : node(nullptr), port_id(0) {}
|
||||||
Port(NodeDef* n, int port) : node(n), port_id(port) {}
|
Port(NodeDefT* n, int port) : node(n), port_id(port) {}
|
||||||
|
|
||||||
// TODO(prazek): ports should keep the constness of GraphView. The only way
|
|
||||||
// to modify graph through the view should be using MutableGraphView.
|
|
||||||
NodeDef* node = nullptr;
|
|
||||||
int port_id = -1;
|
|
||||||
|
|
||||||
bool operator==(const Port& other) const {
|
bool operator==(const Port& other) const {
|
||||||
return node == other.node && port_id == other.port_id;
|
return node == other.node && port_id == other.port_id;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
struct InputPort : public Port {
|
template <typename H>
|
||||||
InputPort() = default;
|
friend H AbslHashValue(H h, const Port& p) {
|
||||||
InputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
|
return H::combine(std::move(h), p.node, p.port_id);
|
||||||
InputPort(const NodeDef* n, int port_id)
|
}
|
||||||
: Port(const_cast<NodeDef*>(n), port_id) {}
|
|
||||||
};
|
NodeDefT* node;
|
||||||
struct OutputPort : public Port {
|
int port_id;
|
||||||
OutputPort() = default;
|
|
||||||
OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct HashPort {
|
struct InputPort : public Port {
|
||||||
std::size_t operator()(const Port& port) const {
|
using Port::Port;
|
||||||
return reinterpret_cast<std::size_t>(port.node) + port.port_id;
|
};
|
||||||
}
|
|
||||||
|
struct OutputPort : public Port {
|
||||||
|
using Port::Port;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Edge {
|
struct Edge {
|
||||||
OutputPort src;
|
Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
|
||||||
InputPort tgt;
|
|
||||||
|
|
||||||
bool operator==(const Edge& other) const {
|
bool operator==(const Edge& other) const {
|
||||||
return src == other.src && tgt == other.tgt;
|
return src == other.src && dst == other.dst;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
struct HashEdge {
|
template <typename H>
|
||||||
std::size_t operator()(const Edge& edge) const {
|
friend H AbslHashValue(H h, const Edge& e) {
|
||||||
return HashPort()(edge.src) + HashPort()(edge.tgt);
|
return H::combine(std::move(h), e.src, e.dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OutputPort src;
|
||||||
|
InputPort dst;
|
||||||
};
|
};
|
||||||
|
|
||||||
explicit GraphView(GraphDef* graph);
|
GraphDefT* graph() const { return graph_; }
|
||||||
GraphDef* GetGraph() const { return graph_; }
|
|
||||||
NodeDef* GetNode(const string& node_name) const;
|
// Find a node by name or return `nullptr` if it's not in a graph view.
|
||||||
|
NodeDefT* GetNode(absl::string_view node_name) const {
|
||||||
|
return gtl::FindWithDefault(nodes_, node_name, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the specified input port. Note that the special '-1' port_id can be
|
// Get the specified input port. Note that the special '-1' port_id can be
|
||||||
// used to access the controlling nodes (i.e. the nodes connected to node_name
|
// used to access the controlling nodes (i.e. the nodes connected to node_name
|
||||||
// through an incoming control dependency).
|
// through an incoming control dependency).
|
||||||
InputPort GetInputPort(const string& node_name, int port_id) const;
|
InputPort GetInputPort(absl::string_view node_name, int port_id) const {
|
||||||
|
return InputPort(GetNode(node_name), port_id);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the specified output port. Note that the special '-1' port_id can be
|
// Get the specified output port. Note that the special '-1' port_id can be
|
||||||
// used to access the controlled nodes (i.e. the nodes connected to node_name
|
// used to access the controlled nodes (i.e. the nodes connected to node_name
|
||||||
// through an outgoing control dependency).
|
// through an outgoing control dependency).
|
||||||
OutputPort GetOutputPort(const string& node_name, int port_id) const;
|
OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
|
||||||
|
return OutputPort(GetNode(node_name), port_id);
|
||||||
|
}
|
||||||
|
|
||||||
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
|
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
|
||||||
// of an output (resp. input) port.
|
// of an output (resp. input) port.
|
||||||
const std::unordered_set<InputPort, HashPort>& GetFanout(
|
const absl::flat_hash_set<InputPort>& GetFanout(
|
||||||
const OutputPort& port) const;
|
const OutputPort& port) const {
|
||||||
std::unordered_set<OutputPort, HashPort> GetFanin(
|
return gtl::FindWithDefault(fanouts_, port, empty_set_);
|
||||||
const InputPort& port) const;
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
|
||||||
|
if (port.port_id >= 0) return {GetRegularFanin(port)};
|
||||||
|
|
||||||
|
// Collect fanin for the control input.
|
||||||
|
absl::flat_hash_set<OutputPort> result;
|
||||||
|
for (int i = port.node->input_size() - 1; i >= 0; --i) {
|
||||||
|
TensorId tensor_id = ParseTensorName(port.node->input(i));
|
||||||
|
if (tensor_id.index() >= 0) break; // we reached regular inputs
|
||||||
|
|
||||||
|
auto it = nodes_.find(tensor_id.node());
|
||||||
|
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Special case: regular (i.e. non-control) input ports can only have one
|
// Special case: regular (i.e. non-control) input ports can only have one
|
||||||
// fanin.
|
// fanin.
|
||||||
const OutputPort GetRegularFanin(const InputPort& port) const;
|
const OutputPort GetRegularFanin(const InputPort& port) const {
|
||||||
|
DCHECK_GE(port.port_id, 0);
|
||||||
|
if (port.port_id < 0) return OutputPort();
|
||||||
|
|
||||||
// Get all the input (resp. output) ports in the immediate fanout (resp fanin)
|
TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
|
||||||
// of a node. Include the controlling nodes iff include_controlling_nodes is
|
return GetOutputPort(tensor_id.node(), tensor_id.index());
|
||||||
// true.
|
}
|
||||||
std::unordered_set<InputPort, HashPort> GetFanouts(
|
|
||||||
const NodeDef& node, bool include_controlled_nodes) const;
|
// Get all the input (resp. output) ports in the immediate fanout (resp
|
||||||
std::unordered_set<OutputPort, HashPort> GetFanins(
|
// fanin) of a node. Include the controlling nodes iff
|
||||||
const NodeDef& node, bool include_controlling_nodes) const;
|
// include_controlling_nodes is true.
|
||||||
|
absl::flat_hash_set<InputPort> GetFanouts(
|
||||||
|
const NodeDef& node, bool include_controlled_nodes) const {
|
||||||
|
absl::flat_hash_set<InputPort> result;
|
||||||
|
|
||||||
|
OutputPort port;
|
||||||
|
port.node = const_cast<NodeDefT*>(&node);
|
||||||
|
const int first_port_id = include_controlled_nodes ? -1 : 0;
|
||||||
|
const int last_port_id =
|
||||||
|
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||||
|
|
||||||
|
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||||
|
port.port_id = i;
|
||||||
|
auto it = fanouts_.find(port);
|
||||||
|
if (it != fanouts_.end()) {
|
||||||
|
result.insert(it->second.begin(), it->second.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_set<OutputPort> GetFanins(
|
||||||
|
const NodeDef& node, bool include_controlling_nodes) const {
|
||||||
|
absl::flat_hash_set<OutputPort> result;
|
||||||
|
for (int i = 0; i < node.input_size(); ++i) {
|
||||||
|
TensorId tensor_id = ParseTensorName(node.input(i));
|
||||||
|
if (tensor_id.index() < 0 && !include_controlling_nodes) break;
|
||||||
|
|
||||||
|
auto it = nodes_.find(tensor_id.node());
|
||||||
|
if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
// Get the number of ports in the immediate fanin of a node. Count the
|
// Get the number of ports in the immediate fanin of a node. Count the
|
||||||
// controlling nodes iff include_controlling_nodes is true.
|
// controlling nodes iff include_controlling_nodes is true.
|
||||||
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const;
|
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
|
||||||
|
int count = 0;
|
||||||
|
for (const string& input : node.input()) {
|
||||||
|
if (!include_controlling_nodes && IsControlInput(input)) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
count += 1;
|
||||||
|
}
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
// Get all the edge in the immediate fanout (resp fanin) of a node. Include
|
// Get the number of ports in the immediate fanout of a node. Count the
|
||||||
// the control edges iff include_controlling_edges is true.
|
// controlled nodes iff include_controlling_nodes is true.
|
||||||
std::unordered_set<Edge, HashEdge> GetFanoutEdges(
|
int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
|
||||||
const NodeDef& node, bool include_controlled_edges) const;
|
int count = 0;
|
||||||
std::unordered_set<Edge, HashEdge> GetFaninEdges(
|
|
||||||
const NodeDef& node, bool include_controlling_edges) const;
|
OutputPort port;
|
||||||
|
port.node = const_cast<NodeDefT*>(&node);
|
||||||
|
const int first_port_id = include_controlling_nodes ? -1 : 0;
|
||||||
|
const int last_port_id =
|
||||||
|
gtl::FindWithDefault(num_regular_outputs_, port.node, -1);
|
||||||
|
|
||||||
|
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||||
|
port.port_id = i;
|
||||||
|
auto it = fanouts_.find(port);
|
||||||
|
if (it != fanouts_.end()) count += it->second.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
return count;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all the edges in the immediate fanout (resp fanin) of a node.
|
||||||
|
// Include the control edges iff include_controlling_edges is true.
|
||||||
|
absl::flat_hash_set<Edge> GetFanoutEdges(
|
||||||
|
const NodeDef& node, bool include_controlled_edges) const {
|
||||||
|
absl::flat_hash_set<Edge> result;
|
||||||
|
|
||||||
|
OutputPort port;
|
||||||
|
port.node = const_cast<NodeDefT*>(&node);
|
||||||
|
const int first_port_id = include_controlled_edges ? -1 : 0;
|
||||||
|
const int last_port_id =
|
||||||
|
gtl::FindWithDefault(num_regular_outputs_, &node, -1);
|
||||||
|
|
||||||
|
for (int i = first_port_id; i <= last_port_id; ++i) {
|
||||||
|
port.port_id = i;
|
||||||
|
auto it = fanouts_.find(port);
|
||||||
|
if (it != fanouts_.end()) {
|
||||||
|
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
|
||||||
|
result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
|
||||||
|
/*dst*/ *itr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_set<Edge> GetFaninEdges(
|
||||||
|
const NodeDef& node, bool include_controlling_edges) const {
|
||||||
|
absl::flat_hash_set<Edge> result;
|
||||||
|
for (int i = 0; i < node.input_size(); ++i) {
|
||||||
|
TensorId tensor_id = ParseTensorName(node.input(i));
|
||||||
|
if (tensor_id.index() < 0 && !include_controlling_edges) break;
|
||||||
|
|
||||||
|
auto it = nodes_.find(tensor_id.node());
|
||||||
|
if (it != nodes_.end()) {
|
||||||
|
result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
|
||||||
|
/*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Add a new `node` to the graph.
|
explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
|
||||||
void AddUniqueNodeOrDie(NodeDef* node);
|
|
||||||
// Add fanout to every `node` input.
|
|
||||||
void AddFanouts(NodeDef* node);
|
|
||||||
std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; }
|
|
||||||
GraphDef* MutableGraph() { return graph_; }
|
|
||||||
|
|
||||||
using FanoutsMapType =
|
void AddUniqueNodeOrDie(NodeDefT* node) {
|
||||||
std::unordered_map<OutputPort, std::unordered_set<InputPort, HashPort>,
|
auto result = nodes_.emplace(node->name(), node);
|
||||||
HashPort>;
|
// TODO(ezhulenev): Replace CHECK with factory method returning
|
||||||
FanoutsMapType* MutableFanouts() { return &fanouts_; }
|
// absl::StatusOr (when available).
|
||||||
|
CHECK(result.second) << "Non unique node name detected: " << node->name();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AddFanouts(NodeDefT* node) {
|
||||||
|
for (int i = 0; i < node->input_size(); ++i) {
|
||||||
|
TensorId tensor_id = ParseTensorName(node->input(i));
|
||||||
|
OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
|
||||||
|
|
||||||
|
if (output.port_id < 0) {
|
||||||
|
fanouts_[output].emplace(node, -1);
|
||||||
|
} else {
|
||||||
|
num_regular_outputs_[output.node] =
|
||||||
|
std::max(num_regular_outputs_[output.node], output.port_id);
|
||||||
|
fanouts_[output].emplace(node, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Access to the mutable internal state for MutableGraphView.
|
||||||
|
absl::flat_hash_map<absl::string_view, NodeDefT*>* mutable_nodes() {
|
||||||
|
return &nodes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>*
|
||||||
|
mutable_fanouts() {
|
||||||
|
return &fanouts_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GraphDef* graph_;
|
GraphDefT* graph_; // must outlive the graph view
|
||||||
std::unordered_map<string, NodeDef*> nodes_;
|
absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;
|
||||||
std::unordered_set<InputPort, HashPort> empty_set_;
|
absl::flat_hash_set<InputPort> empty_set_;
|
||||||
FanoutsMapType fanouts_;
|
absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;
|
||||||
std::unordered_map<const NodeDef*, int> num_regular_outputs_;
|
std::unordered_map<NodeDefT*, int> num_regular_outputs_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
// Immutable GraphView that keeps the constness of the GraphDef. If you need to
|
||||||
|
// mutate the graph or the nodes via the graph view lookup functions, see
|
||||||
|
// MutableGraphView.
|
||||||
|
class GraphView
|
||||||
|
: public internal::GraphViewInternal<const GraphDef, const NodeDef> {
|
||||||
|
public:
|
||||||
|
explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
|
||||||
|
for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
|
||||||
|
for (const NodeDef& node : graph->node()) AddFanouts(&node);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace grappler
|
} // end namespace grappler
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
#include "tensorflow/core/grappler/graph_view.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/cc/ops/parsing_ops.h"
|
#include "tensorflow/cc/ops/parsing_ops.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
@ -158,19 +160,22 @@ TEST_F(GraphViewTest, BasicGraph) {
|
|||||||
|
|
||||||
const NodeDef* add_node = graph.GetNode("AddN");
|
const NodeDef* add_node = graph.GetNode("AddN");
|
||||||
EXPECT_NE(nullptr, add_node);
|
EXPECT_NE(nullptr, add_node);
|
||||||
string fanouts;
|
|
||||||
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
|
|
||||||
strings::StrAppend(&fanouts,
|
|
||||||
strings::StrCat(fo.node->name(), ":", fo.port_id, " "));
|
|
||||||
}
|
|
||||||
EXPECT_EQ("AddN_2:0 AddN_3:0 ", fanouts);
|
|
||||||
|
|
||||||
string fanins;
|
absl::flat_hash_set<string> fanouts;
|
||||||
for (const auto& fi : graph.GetFanins(*add_node, false)) {
|
absl::flat_hash_set<string> expected_fanouts = {"AddN_2:0", "AddN_3:0"};
|
||||||
strings::StrAppend(&fanins,
|
for (const auto& fo : graph.GetFanouts(*add_node, false)) {
|
||||||
strings::StrCat(fi.node->name(), ":", fi.port_id, " "));
|
fanouts.insert(absl::StrCat(fo.node->name(), ":", fo.port_id));
|
||||||
}
|
}
|
||||||
EXPECT_EQ("Square_1:0 Square:0 ", fanins);
|
EXPECT_EQ(graph.NumFanouts(*add_node, false), 2);
|
||||||
|
EXPECT_EQ(fanouts, expected_fanouts);
|
||||||
|
|
||||||
|
absl::flat_hash_set<string> fanins;
|
||||||
|
absl::flat_hash_set<string> expected_fanins = {"Square_1:0", "Square:0"};
|
||||||
|
for (const auto& fi : graph.GetFanins(*add_node, false)) {
|
||||||
|
fanins.insert(absl::StrCat(fi.node->name(), ":", fi.port_id));
|
||||||
|
}
|
||||||
|
EXPECT_EQ(graph.NumFanins(*add_node, false), 2);
|
||||||
|
EXPECT_EQ(fanins, expected_fanins);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GraphViewTest, ControlDependencies) {
|
TEST_F(GraphViewTest, ControlDependencies) {
|
||||||
|
@ -19,8 +19,26 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
|
const absl::flat_hash_set<MutableGraphView::InputPort>&
|
||||||
|
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
|
||||||
|
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
|
||||||
|
port.port_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
|
||||||
|
const GraphView::InputPort& port) const {
|
||||||
|
return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
|
||||||
|
port.port_id));
|
||||||
|
}
|
||||||
|
|
||||||
|
const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
|
||||||
|
const GraphView::InputPort& port) const {
|
||||||
|
return GetRegularFanin(MutableGraphView::InputPort(
|
||||||
|
const_cast<NodeDef*>(port.node), port.port_id));
|
||||||
|
}
|
||||||
|
|
||||||
NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
||||||
auto* node_in_graph = GetGraph()->add_node();
|
auto* node_in_graph = graph()->add_node();
|
||||||
*node_in_graph = std::move(node);
|
*node_in_graph = std::move(node);
|
||||||
|
|
||||||
AddUniqueNodeOrDie(node_in_graph);
|
AddUniqueNodeOrDie(node_in_graph);
|
||||||
@ -31,7 +49,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
|
|||||||
|
|
||||||
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
||||||
const int output_port_id) {
|
const int output_port_id) {
|
||||||
auto* node_in_graph = GetGraph()->add_node();
|
auto* node_in_graph = graph()->add_node();
|
||||||
*node_in_graph = std::move(node);
|
*node_in_graph = std::move(node);
|
||||||
|
|
||||||
AddUniqueNodeOrDie(node_in_graph);
|
AddUniqueNodeOrDie(node_in_graph);
|
||||||
@ -46,8 +64,7 @@ NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node,
|
|||||||
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
||||||
const NodeDef& new_input,
|
const NodeDef& new_input,
|
||||||
const int output_port_id) {
|
const int output_port_id) {
|
||||||
GraphView::OutputPort output_port =
|
OutputPort output_port = GetOutputPort(old_input.name(), output_port_id);
|
||||||
GetOutputPort(old_input.name(), output_port_id);
|
|
||||||
auto fanout = GetFanout(output_port);
|
auto fanout = GetFanout(output_port);
|
||||||
for (auto& input_port : fanout) {
|
for (auto& input_port : fanout) {
|
||||||
input_port.node->set_input(input_port.port_id, new_input.name());
|
input_port.node->set_input(input_port.port_id, new_input.name());
|
||||||
@ -57,17 +74,17 @@ void MutableGraphView::ReplaceInput(const NodeDef& old_input,
|
|||||||
|
|
||||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
||||||
for (const string& node_name_to_delete : nodes_to_delete)
|
for (const string& node_name_to_delete : nodes_to_delete)
|
||||||
RemoveFanouts(MutableNodes()->at(node_name_to_delete));
|
RemoveFanouts(mutable_nodes()->at(node_name_to_delete));
|
||||||
for (const string& node_name_to_delete : nodes_to_delete)
|
for (const string& node_name_to_delete : nodes_to_delete)
|
||||||
MutableNodes()->erase(node_name_to_delete);
|
mutable_nodes()->erase(node_name_to_delete);
|
||||||
EraseNodesFromGraph(nodes_to_delete, GetGraph());
|
EraseNodesFromGraph(nodes_to_delete, graph());
|
||||||
}
|
}
|
||||||
|
|
||||||
void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
||||||
for (int i = 0; i < node->input_size(); ++i) {
|
for (int i = 0; i < node->input_size(); ++i) {
|
||||||
OutputPort fanin;
|
OutputPort fanin;
|
||||||
string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
|
string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
|
||||||
fanin.node = (*MutableNodes())[fanin_name];
|
fanin.node = (*mutable_nodes())[fanin_name];
|
||||||
|
|
||||||
InputPort input;
|
InputPort input;
|
||||||
input.node = node;
|
input.node = node;
|
||||||
@ -76,7 +93,7 @@ void MutableGraphView::RemoveFanouts(NodeDef* node) {
|
|||||||
else
|
else
|
||||||
input.port_id = i;
|
input.port_id = i;
|
||||||
|
|
||||||
(*MutableFanouts())[fanin].erase(input);
|
(*mutable_fanouts())[fanin].erase(input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,11 +24,25 @@ namespace grappler {
|
|||||||
// A utility class to simplify the traversal of a GraphDef that, unlike
|
// A utility class to simplify the traversal of a GraphDef that, unlike
|
||||||
// GraphView, supports updating the graph. Note that you should not modify the
|
// GraphView, supports updating the graph. Note that you should not modify the
|
||||||
// graph separately, because the view will get out of sync.
|
// graph separately, because the view will get out of sync.
|
||||||
class MutableGraphView : public GraphView {
|
|
||||||
public:
|
|
||||||
using GraphView::GraphView;
|
|
||||||
|
|
||||||
GraphDef* GetGraph() { return MutableGraph(); }
|
class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||||
|
public:
|
||||||
|
explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
|
||||||
|
for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
|
||||||
|
for (NodeDef& node : *graph->mutable_node()) AddFanouts(&node);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup fanouts/fanins using immutable ports.
|
||||||
|
using GraphViewInternal::GetFanout;
|
||||||
|
const absl::flat_hash_set<InputPort>& GetFanout(
|
||||||
|
const GraphView::OutputPort& port) const;
|
||||||
|
|
||||||
|
using GraphViewInternal::GetFanin;
|
||||||
|
absl::flat_hash_set<OutputPort> GetFanin(
|
||||||
|
const GraphView::InputPort& port) const;
|
||||||
|
|
||||||
|
using GraphViewInternal::GetRegularFanin;
|
||||||
|
const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
|
||||||
|
|
||||||
// Adds a new node to graph and updates the view.
|
// Adds a new node to graph and updates the view.
|
||||||
NodeDef* AddNode(NodeDef&& node);
|
NodeDef* AddNode(NodeDef&& node);
|
||||||
|
@ -26,7 +26,8 @@ namespace {
|
|||||||
bool FindChildWithName(const MutableGraphView& graph,
|
bool FindChildWithName(const MutableGraphView& graph,
|
||||||
const string& output_port_name,
|
const string& output_port_name,
|
||||||
const string& input_name) {
|
const string& input_name) {
|
||||||
GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0);
|
MutableGraphView::OutputPort output_port =
|
||||||
|
graph.GetOutputPort(output_port_name, 0);
|
||||||
auto fanout = graph.GetFanout(output_port);
|
auto fanout = graph.GetFanout(output_port);
|
||||||
for (auto& input_port : fanout) {
|
for (auto& input_port : fanout) {
|
||||||
if (input_port.node->name() == input_name) return true;
|
if (input_port.node->name() == input_name) return true;
|
||||||
@ -59,10 +60,10 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) {
|
|||||||
GraphDef new_graph = item.graph;
|
GraphDef new_graph = item.graph;
|
||||||
MutableGraphView graph(&new_graph);
|
MutableGraphView graph(&new_graph);
|
||||||
|
|
||||||
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||||
EXPECT_EQ("AddN", input.node->name());
|
EXPECT_EQ("AddN", input.node->name());
|
||||||
EXPECT_EQ(0, input.port_id);
|
EXPECT_EQ(0, input.port_id);
|
||||||
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||||
EXPECT_EQ("Square", fanin.node->name());
|
EXPECT_EQ("Square", fanin.node->name());
|
||||||
EXPECT_EQ(0, fanin.port_id);
|
EXPECT_EQ(0, fanin.port_id);
|
||||||
|
|
||||||
@ -89,7 +90,7 @@ TEST(MutableGraphViewTest, InsertNodes) {
|
|||||||
GraphDef new_graph = item.graph;
|
GraphDef new_graph = item.graph;
|
||||||
MutableGraphView graph(&new_graph);
|
MutableGraphView graph(&new_graph);
|
||||||
|
|
||||||
GraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
MutableGraphView::InputPort input = graph.GetInputPort("AddN", 0);
|
||||||
|
|
||||||
NodeDef new_node = *input.node;
|
NodeDef new_node = *input.node;
|
||||||
new_node.set_name("new_node");
|
new_node.set_name("new_node");
|
||||||
|
@ -145,8 +145,8 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/utils:functions",
|
"//tensorflow/core/grappler/utils:functions",
|
||||||
@ -422,8 +422,8 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/clusters:cluster",
|
"//tensorflow/core/grappler/clusters:cluster",
|
||||||
@ -625,12 +625,13 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/costs:graph_properties",
|
"//tensorflow/core/grappler/costs:graph_properties",
|
||||||
"//tensorflow/core/grappler/utils:frame",
|
"//tensorflow/core/grappler/utils:frame",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -663,8 +664,8 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
|
||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/costs:graph_properties",
|
"//tensorflow/core/grappler/costs:graph_properties",
|
||||||
|
@ -37,7 +37,7 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
|
|||||||
const FunctionDef& fused_function,
|
const FunctionDef& fused_function,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef fused_node;
|
NodeDef fused_node;
|
||||||
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("fused_filter", graph->graph(),
|
||||||
&fused_node);
|
&fused_node);
|
||||||
|
|
||||||
fused_node.set_op("FilterDataset");
|
fused_node.set_op("FilterDataset");
|
||||||
|
@ -72,7 +72,7 @@ NodeDef* AddScalarConstNodeHelper(
|
|||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef node;
|
NodeDef node;
|
||||||
node.set_op(kConstOpName);
|
node.set_op(kConstOpName);
|
||||||
SetUniqueGraphNodeName(kConstOpName, graph->GetGraph(), &node);
|
SetUniqueGraphNodeName(kConstOpName, graph->graph(), &node);
|
||||||
|
|
||||||
(*node.mutable_attr())["dtype"].set_type(dtype);
|
(*node.mutable_attr())["dtype"].set_type(dtype);
|
||||||
std::unique_ptr<tensorflow::TensorProto> tensor =
|
std::unique_ptr<tensorflow::TensorProto> tensor =
|
||||||
@ -92,7 +92,7 @@ NodeDef* AddScalarConstNodeHelper(
|
|||||||
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
|
NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
|
||||||
NodeDef node;
|
NodeDef node;
|
||||||
node.set_op("Placeholder");
|
node.set_op("Placeholder");
|
||||||
SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
|
SetUniqueGraphNodeName(node.op(), graph->graph(), &node);
|
||||||
(*node.mutable_attr())["dtype"].set_type(dtype);
|
(*node.mutable_attr())["dtype"].set_type(dtype);
|
||||||
TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
|
TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
|
||||||
shape->set_unknown_rank(false);
|
shape->set_unknown_rank(false);
|
||||||
@ -107,7 +107,7 @@ NodeDef* AddNode(StringPiece name, StringPiece op,
|
|||||||
if (!name.empty()) {
|
if (!name.empty()) {
|
||||||
node.set_name(string(name));
|
node.set_name(string(name));
|
||||||
} else {
|
} else {
|
||||||
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
|
SetUniqueGraphNodeName(op, graph->graph(), &node);
|
||||||
}
|
}
|
||||||
node.set_op(string(op));
|
node.set_op(string(op));
|
||||||
for (const string& input : inputs) {
|
for (const string& input : inputs) {
|
||||||
@ -228,7 +228,7 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
|
|||||||
|
|
||||||
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
|
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
|
||||||
if (node.input_size() == 0) return nullptr;
|
if (node.input_size() == 0) return nullptr;
|
||||||
GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
|
MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
|
||||||
return graph.GetRegularFanin(input_port).node;
|
return graph.GetRegularFanin(input_port).node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +41,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeBool) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
|
NodeDef* bool_node = AddScalarConstNode<bool>(true, &graph);
|
||||||
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.GetGraph()));
|
EXPECT_TRUE(ContainsGraphNodeWithName(bool_node->name(), *graph.graph()));
|
||||||
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
|
EXPECT_EQ(bool_node->attr().at("value").tensor().bool_val(0), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,8 +49,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeDouble) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
|
NodeDef* double_node = AddScalarConstNode<double>(3.14, &graph);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(ContainsGraphNodeWithName(double_node->name(), *graph.graph()));
|
||||||
ContainsGraphNodeWithName(double_node->name(), *graph.GetGraph()));
|
|
||||||
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
|
EXPECT_FLOAT_EQ(double_node->attr().at("value").tensor().double_val(0), 3.14);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,7 +57,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeFloat) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
|
NodeDef* float_node = AddScalarConstNode<float>(3.14, &graph);
|
||||||
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.GetGraph()));
|
EXPECT_TRUE(ContainsGraphNodeWithName(float_node->name(), *graph.graph()));
|
||||||
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
|
EXPECT_FLOAT_EQ(float_node->attr().at("value").tensor().float_val(0), 3.14);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +65,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
|
NodeDef* int_node = AddScalarConstNode<int>(42, &graph);
|
||||||
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.GetGraph()));
|
EXPECT_TRUE(ContainsGraphNodeWithName(int_node->name(), *graph.graph()));
|
||||||
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
|
EXPECT_EQ(int_node->attr().at("value").tensor().int_val(0), 42);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -74,7 +73,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeInt64) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
|
NodeDef* int64_node = AddScalarConstNode<int64>(42, &graph);
|
||||||
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.GetGraph()));
|
EXPECT_TRUE(ContainsGraphNodeWithName(int64_node->name(), *graph.graph()));
|
||||||
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
|
EXPECT_EQ(int64_node->attr().at("value").tensor().int64_val(0), 42);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,8 +81,7 @@ TEST(GraphUtilsTest, AddScalarConstNodeString) {
|
|||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
|
NodeDef* string_node = AddScalarConstNode<StringPiece>("hello", &graph);
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(ContainsGraphNodeWithName(string_node->name(), *graph.graph()));
|
||||||
ContainsGraphNodeWithName(string_node->name(), *graph.GetGraph()));
|
|
||||||
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
|
EXPECT_EQ(string_node->attr().at("value").tensor().string_val(0), "hello");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,13 +104,13 @@ TEST(GraphUtilsTest, Compare) {
|
|||||||
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
|
TEST(GraphUtilsTest, ContainsGraphNodeWithName) {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
|
||||||
|
|
||||||
AddNode("A", "OpA", {}, {}, &graph);
|
AddNode("A", "OpA", {}, {}, &graph);
|
||||||
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
EXPECT_TRUE(ContainsGraphNodeWithName("A", *graph.graph()));
|
||||||
|
|
||||||
graph.DeleteNodes({"A"});
|
graph.DeleteNodes({"A"});
|
||||||
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.GetGraph()));
|
EXPECT_TRUE(!ContainsGraphNodeWithName("A", *graph.graph()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
|
TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
|
||||||
@ -128,25 +126,25 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
|
|||||||
TEST(GraphUtilsTest, ContainsNodeWithOp) {
|
TEST(GraphUtilsTest, ContainsNodeWithOp) {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
|
||||||
|
|
||||||
AddNode("A", "OpA", {}, {}, &graph);
|
AddNode("A", "OpA", {}, {}, &graph);
|
||||||
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
EXPECT_TRUE(ContainsNodeWithOp("OpA", *graph.graph()));
|
||||||
|
|
||||||
graph.DeleteNodes({"A"});
|
graph.DeleteNodes({"A"});
|
||||||
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.GetGraph()));
|
EXPECT_TRUE(!ContainsNodeWithOp("OpA", *graph.graph()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GraphUtilsTest, FindGraphNodeWithName) {
|
TEST(GraphUtilsTest, FindGraphNodeWithName) {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||||
|
|
||||||
AddNode("A", "OpA", {}, {}, &graph);
|
AddNode("A", "OpA", {}, {}, &graph);
|
||||||
EXPECT_NE(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
EXPECT_NE(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||||
|
|
||||||
graph.DeleteNodes({"A"});
|
graph.DeleteNodes({"A"});
|
||||||
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
|
EXPECT_EQ(FindGraphNodeWithName("A", *graph.graph()), -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
||||||
@ -162,35 +160,35 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
|
|||||||
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
|
TEST(GraphUtilsTest, FindGraphNodeWithOp) {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
|
||||||
|
|
||||||
AddNode("A", "OpA", {}, {}, &graph);
|
AddNode("A", "OpA", {}, {}, &graph);
|
||||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||||
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
||||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
|
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), 0);
|
||||||
|
|
||||||
graph.DeleteNodes({"B"});
|
graph.DeleteNodes({"B"});
|
||||||
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
|
EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.graph()), -1);
|
||||||
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
|
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.graph()), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
|
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
|
||||||
GraphDef graph_def;
|
GraphDef graph_def;
|
||||||
MutableGraphView graph(&graph_def);
|
MutableGraphView graph(&graph_def);
|
||||||
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
|
EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.graph()), -1);
|
||||||
|
|
||||||
AddNode("A", "OpA", {}, {}, &graph);
|
AddNode("A", "OpA", {}, {}, &graph);
|
||||||
AddNode("B", "OpB", {"A"}, {}, &graph);
|
AddNode("B", "OpB", {"A"}, {}, &graph);
|
||||||
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
AddNode("A2", "OpA", {"B"}, {}, &graph);
|
||||||
std::vector<int> result_indices =
|
std::vector<int> result_indices =
|
||||||
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
|
FindAllGraphNodesWithOp("OpA", *graph.graph());
|
||||||
EXPECT_EQ(result_indices.size(), 2);
|
EXPECT_EQ(result_indices.size(), 2);
|
||||||
EXPECT_EQ(result_indices.at(0), 0);
|
EXPECT_EQ(result_indices.at(0), 0);
|
||||||
EXPECT_EQ(result_indices.at(1), 2);
|
EXPECT_EQ(result_indices.at(1), 2);
|
||||||
|
|
||||||
graph.DeleteNodes({"A2"});
|
graph.DeleteNodes({"A2"});
|
||||||
std::vector<int> result_indices_new =
|
std::vector<int> result_indices_new =
|
||||||
FindAllGraphNodesWithOp("OpA", *graph.GetGraph());
|
FindAllGraphNodesWithOp("OpA", *graph.graph());
|
||||||
EXPECT_EQ(result_indices_new.size(), 1);
|
EXPECT_EQ(result_indices_new.size(), 1);
|
||||||
EXPECT_EQ(result_indices_new.at(0), 0);
|
EXPECT_EQ(result_indices_new.at(0), 0);
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
|
|||||||
const FunctionDef& stateless_function,
|
const FunctionDef& stateless_function,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef stateless_map;
|
NodeDef stateless_map;
|
||||||
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("stateless_map", graph->graph(),
|
||||||
&stateless_map);
|
&stateless_map);
|
||||||
|
|
||||||
stateless_map.set_op("MapDataset");
|
stateless_map.set_op("MapDataset");
|
||||||
@ -68,7 +68,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
|
|||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef random_dataset;
|
NodeDef random_dataset;
|
||||||
random_dataset.set_op("RandomDataset");
|
random_dataset.set_op("RandomDataset");
|
||||||
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->graph(),
|
||||||
&random_dataset);
|
&random_dataset);
|
||||||
|
|
||||||
const auto* seed = graph_utils::AddScalarConstNode<int64>(
|
const auto* seed = graph_utils::AddScalarConstNode<int64>(
|
||||||
@ -89,7 +89,7 @@ NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
|
|||||||
NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
|
NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
|
||||||
NodeDef batch_dataset;
|
NodeDef batch_dataset;
|
||||||
batch_dataset.set_op("BatchDatasetV2");
|
batch_dataset.set_op("BatchDatasetV2");
|
||||||
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->graph(),
|
||||||
&batch_dataset);
|
&batch_dataset);
|
||||||
const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
|
const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
|
||||||
const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
|
const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph);
|
||||||
@ -112,7 +112,7 @@ NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
|
|||||||
NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
|
NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef zip_node;
|
NodeDef zip_node;
|
||||||
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->graph(),
|
||||||
&zip_node);
|
&zip_node);
|
||||||
|
|
||||||
zip_node.set_op("ZipDataset");
|
zip_node.set_op("ZipDataset");
|
||||||
|
@ -37,8 +37,7 @@ NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
|
|||||||
NodeDef new_node;
|
NodeDef new_node;
|
||||||
new_node.set_op(kInsertOpName);
|
new_node.set_op(kInsertOpName);
|
||||||
graph_utils::SetUniqueGraphNodeName(
|
graph_utils::SetUniqueGraphNodeName(
|
||||||
strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(),
|
strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node);
|
||||||
&new_node);
|
|
||||||
// Set the input of LatencyDataset node as `node`
|
// Set the input of LatencyDataset node as `node`
|
||||||
new_node.add_input(node.name());
|
new_node.add_input(node.name());
|
||||||
|
|
||||||
@ -81,7 +80,8 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// node corresponds to a `Dataset` op.
|
// node corresponds to a `Dataset` op.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0);
|
MutableGraphView::OutputPort output_port =
|
||||||
|
graph.GetOutputPort(node.name(), 0);
|
||||||
auto fanout = graph.GetFanout(output_port);
|
auto fanout = graph.GetFanout(output_port);
|
||||||
if (fanout.size() > 1) {
|
if (fanout.size() > 1) {
|
||||||
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
|
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
|
||||||
|
@ -29,7 +29,7 @@ namespace {
|
|||||||
|
|
||||||
NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) {
|
NodeDef MakeNumaAwareNode(const NodeDef& node, MutableGraphView* graph) {
|
||||||
NodeDef numa_aware_node = node;
|
NodeDef numa_aware_node = node;
|
||||||
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("make_numa_aware", graph->graph(),
|
||||||
&numa_aware_node);
|
&numa_aware_node);
|
||||||
numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
|
numa_aware_node.set_op("ExperimentalNumaMapAndBatchDataset");
|
||||||
return numa_aware_node;
|
return numa_aware_node;
|
||||||
|
@ -36,8 +36,7 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
|
|||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef new_node;
|
NodeDef new_node;
|
||||||
new_node.set_op(kFusedOpName);
|
new_node.set_op(kFusedOpName);
|
||||||
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->graph(), &new_node);
|
||||||
&new_node);
|
|
||||||
|
|
||||||
// Set the `input` input argument.
|
// Set the `input` input argument.
|
||||||
new_node.add_input(map_node.input(0));
|
new_node.add_input(map_node.input(0));
|
||||||
|
@ -309,7 +309,7 @@ TEST(MapAndBatchFusionTest, NoChange) {
|
|||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
|
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -37,8 +37,7 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
|
|||||||
const FunctionDef& fused_function,
|
const FunctionDef& fused_function,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef fused_node;
|
NodeDef fused_node;
|
||||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
|
||||||
&fused_node);
|
|
||||||
fused_node.set_op("MapDataset");
|
fused_node.set_op("MapDataset");
|
||||||
fused_node.add_input(map_node.input(0));
|
fused_node.add_input(map_node.input(0));
|
||||||
|
|
||||||
@ -72,8 +71,8 @@ NodeDef MakeFilterByLastComponentNode(const NodeDef& fused_map_node,
|
|||||||
const NodeDef& filter_node,
|
const NodeDef& filter_node,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef filter_by_component;
|
NodeDef filter_by_component;
|
||||||
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent",
|
graph_utils::SetUniqueGraphNodeName("FilterByLastComponent", graph->graph(),
|
||||||
graph->GetGraph(), &filter_by_component);
|
&filter_by_component);
|
||||||
filter_by_component.set_op("FilterByLastComponentDataset");
|
filter_by_component.set_op("FilterByLastComponentDataset");
|
||||||
filter_by_component.add_input(fused_map_node.name());
|
filter_by_component.add_input(fused_map_node.name());
|
||||||
|
|
||||||
|
@ -39,8 +39,7 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
|
|||||||
const FunctionDef& fused_function,
|
const FunctionDef& fused_function,
|
||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef fused_node;
|
NodeDef fused_node;
|
||||||
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("fused_map", graph->graph(), &fused_node);
|
||||||
&fused_node);
|
|
||||||
fused_node.set_op("MapDataset");
|
fused_node.set_op("MapDataset");
|
||||||
fused_node.add_input(parent_map_node.input(0));
|
fused_node.add_input(parent_map_node.input(0));
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ bool CanParallelize(const FunctionDef& function,
|
|||||||
|
|
||||||
NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
|
NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
|
||||||
NodeDef parallel_map = map_node;
|
NodeDef parallel_map = map_node;
|
||||||
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName("parallel_map", graph->graph(),
|
||||||
&parallel_map);
|
&parallel_map);
|
||||||
parallel_map.set_op("ParallelMapDataset");
|
parallel_map.set_op("ParallelMapDataset");
|
||||||
// TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
|
// TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
|
||||||
|
@ -147,7 +147,7 @@ NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
|
|||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef batch_node;
|
NodeDef batch_node;
|
||||||
batch_node.set_op(old_batch_node.op());
|
batch_node.set_op(old_batch_node.op());
|
||||||
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->graph(),
|
||||||
&batch_node);
|
&batch_node);
|
||||||
|
|
||||||
// Set the `input_dataset` input argument
|
// Set the `input_dataset` input argument
|
||||||
@ -187,8 +187,7 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
|
|||||||
MutableGraphView* graph) {
|
MutableGraphView* graph) {
|
||||||
NodeDef map_node;
|
NodeDef map_node;
|
||||||
map_node.set_op(old_map_node.op());
|
map_node.set_op(old_map_node.op());
|
||||||
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
|
graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->graph(), &map_node);
|
||||||
&map_node);
|
|
||||||
|
|
||||||
// Set the `input_dataset` input argument
|
// Set the `input_dataset` input argument
|
||||||
map_node.add_input(new_batch_node.name());
|
map_node.add_input(new_batch_node.name());
|
||||||
|
@ -30,7 +30,7 @@ namespace tensorflow {
|
|||||||
namespace grappler {
|
namespace grappler {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
|
bool IsTakeAll(const NodeDef& take_node, const MutableGraphView& graph) {
|
||||||
if (take_node.op() != "TakeDataset") return false;
|
if (take_node.op() != "TakeDataset") return false;
|
||||||
|
|
||||||
const auto& count_node = *graph.GetNode(take_node.input(1));
|
const auto& count_node = *graph.GetNode(take_node.input(1));
|
||||||
@ -44,25 +44,26 @@ bool IsConstNodeWithValue(const NodeDef& node, int value) {
|
|||||||
return node.attr().at("value").tensor().int64_val(0) == value;
|
return node.attr().at("value").tensor().int64_val(0) == value;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
|
bool IsSkipNone(const NodeDef& skip_node, const MutableGraphView& graph) {
|
||||||
if (skip_node.op() != "SkipDataset") return false;
|
if (skip_node.op() != "SkipDataset") return false;
|
||||||
// We are looking only for skip(0) nodes.
|
// We are looking only for skip(0) nodes.
|
||||||
return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
|
return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
|
bool IsRepeatOne(const NodeDef& repeat_node, const MutableGraphView& graph) {
|
||||||
if (repeat_node.op() != "RepeatDataset") return false;
|
if (repeat_node.op() != "RepeatDataset") return false;
|
||||||
// We are looking only for repeat(1) nodes.
|
// We are looking only for repeat(1) nodes.
|
||||||
return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
|
return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsPrefetchZero(const NodeDef& prefetch_node, const GraphView& graph) {
|
bool IsPrefetchZero(const NodeDef& prefetch_node,
|
||||||
|
const MutableGraphView& graph) {
|
||||||
if (prefetch_node.op() != "PrefetchDataset") return false;
|
if (prefetch_node.op() != "PrefetchDataset") return false;
|
||||||
// We are looking only for prefetch(0) nodes.
|
// We are looking only for prefetch(0) nodes.
|
||||||
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
|
return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
|
bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) {
|
||||||
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
|
return IsTakeAll(node, graph) || IsSkipNone(node, graph) ||
|
||||||
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
|
IsRepeatOne(node, graph) || IsPrefetchZero(node, graph);
|
||||||
}
|
}
|
||||||
|
@ -127,7 +127,7 @@ TEST(ShuffleAndRepeatFusionTest, NoChange) {
|
|||||||
GraphDef output;
|
GraphDef output;
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
EXPECT_TRUE(graph_utils::Compare(*graph.GetGraph(), output));
|
EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
@ -31,8 +31,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op_def.pb.h"
|
#include "tensorflow/core/framework/op_def.pb.h"
|
||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/graph/graph_constructor.h"
|
#include "tensorflow/core/graph/graph_constructor.h"
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
@ -219,8 +219,7 @@ class FunctionOptimizerContext {
|
|||||||
: grappler_item_id_(item.id),
|
: grappler_item_id_(item.id),
|
||||||
graph_version_(item.graph.versions().producer()),
|
graph_version_(item.graph.versions().producer()),
|
||||||
function_library_(OpRegistry::Global(), item.graph.library()),
|
function_library_(OpRegistry::Global(), item.graph.library()),
|
||||||
// GraphView doesn't not modify the graph or the nodes.
|
graph_view_(&item.graph) {
|
||||||
graph_view_(const_cast<GraphDef*>(&item.graph)) {
|
|
||||||
InitializeTrulyConstNodes(item);
|
InitializeTrulyConstNodes(item);
|
||||||
InitializeInlinedFunctions(opt_level, item);
|
InitializeInlinedFunctions(opt_level, item);
|
||||||
InitializeFetchNodes(item);
|
InitializeFetchNodes(item);
|
||||||
@ -1133,7 +1132,7 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// Function specialization might change the number of function outputs, so we
|
// Function specialization might change the number of function outputs, so we
|
||||||
// have to process the final optimized graph and update all the node mapping.
|
// have to process the final optimized graph and update all the node mapping.
|
||||||
if (ctx.RequiresOutputMapping()) {
|
if (ctx.RequiresOutputMapping()) {
|
||||||
GraphView optimized_graph_view(optimized_graph);
|
MutableGraphView optimized_graph_view(optimized_graph);
|
||||||
for (const auto& output_mapping : ctx.output_mappings()) {
|
for (const auto& output_mapping : ctx.output_mappings()) {
|
||||||
const auto& node_name = output_mapping.first;
|
const auto& node_name = output_mapping.first;
|
||||||
const auto& mappings = output_mapping.second;
|
const auto& mappings = output_mapping.second;
|
||||||
@ -1143,11 +1142,11 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
int to = mapping.second;
|
int to = mapping.second;
|
||||||
|
|
||||||
// Get the output port corresponding to the old output position.
|
// Get the output port corresponding to the old output position.
|
||||||
GraphView::OutputPort from_port =
|
MutableGraphView::OutputPort from_port =
|
||||||
optimized_graph_view.GetOutputPort(node_name, from);
|
optimized_graph_view.GetOutputPort(node_name, from);
|
||||||
|
|
||||||
// Update all input ports that read from old output port.
|
// Update all input ports that read from old output port.
|
||||||
for (GraphView::InputPort to_port :
|
for (MutableGraphView::InputPort to_port :
|
||||||
optimized_graph_view.GetFanout(from_port)) {
|
optimized_graph_view.GetFanout(from_port)) {
|
||||||
*to_port.node->mutable_input(to_port.port_id) =
|
*to_port.node->mutable_input(to_port.port_id) =
|
||||||
strings::StrCat(node_name, ":", to);
|
strings::StrCat(node_name, ":", to);
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/framework/allocator.h"
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
@ -29,8 +30,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
|
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
|
||||||
@ -565,13 +566,14 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
|
Status CheckForDeadFanout(const MutableGraphView& view,
|
||||||
const NodeMap& node_map, DeviceBase* cpu_device,
|
const NodeDef& switch_node, const NodeMap& node_map,
|
||||||
ResourceMgr* resource_mgr, bool* has_dead_fanout,
|
DeviceBase* cpu_device, ResourceMgr* resource_mgr,
|
||||||
int* dead_fanout) {
|
bool* has_dead_fanout, int* dead_fanout) {
|
||||||
*has_dead_fanout = false;
|
*has_dead_fanout = false;
|
||||||
GraphView::InputPort switch_loopcond_port(&switch_node, 1);
|
GraphView::InputPort switch_loopcond_port(&switch_node, 1);
|
||||||
NodeDef* switch_predicate = view.GetRegularFanin(switch_loopcond_port).node;
|
const NodeDef* switch_predicate =
|
||||||
|
view.GetRegularFanin(switch_loopcond_port).node;
|
||||||
|
|
||||||
// CASE 1: Control is a constant.
|
// CASE 1: Control is a constant.
|
||||||
if (IsConstant(*switch_predicate)) {
|
if (IsConstant(*switch_predicate)) {
|
||||||
@ -582,7 +584,7 @@ Status CheckForDeadFanout(const GraphView& view, const NodeDef& switch_node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
GraphView::InputPort switch_input_port(&switch_node, 0);
|
GraphView::InputPort switch_input_port(&switch_node, 0);
|
||||||
NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
|
const NodeDef* switch_input = view.GetRegularFanin(switch_input_port).node;
|
||||||
|
|
||||||
// CASE 2: Zero-iteration while loop.
|
// CASE 2: Zero-iteration while loop.
|
||||||
// We check if its a while loop such that the condition is a simple binary
|
// We check if its a while loop such that the condition is a simple binary
|
||||||
@ -707,10 +709,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
|
std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
|
||||||
// TODO(bsteiner): also rewrite switches as identity. For now we just record
|
// TODO(bsteiner): also rewrite switches as identity. For now we just record
|
||||||
// them
|
// them
|
||||||
std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
|
absl::flat_hash_set<GraphView::OutputPort> identity_switches;
|
||||||
identity_switches;
|
|
||||||
|
|
||||||
GraphView view(optimized_graph);
|
MutableGraphView view(optimized_graph);
|
||||||
for (const NodeDef& node : optimized_graph->node()) {
|
for (const NodeDef& node : optimized_graph->node()) {
|
||||||
if (!IsSwitch(node)) {
|
if (!IsSwitch(node)) {
|
||||||
continue;
|
continue;
|
||||||
@ -727,11 +728,12 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
if (!has_dead_fanout) {
|
if (!has_dead_fanout) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
|
GraphView::OutputPort dead(&node, dead_fanout);
|
||||||
identity_switches.insert(dead);
|
identity_switches.insert(dead);
|
||||||
|
|
||||||
SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
|
SetVector<MutableGraphView::InputPort, absl::Hash<MutableGraphView::Port>>
|
||||||
for (const GraphView::InputPort& port : view.GetFanout(dead)) {
|
zombie_inputs;
|
||||||
|
for (const MutableGraphView::InputPort& port : view.GetFanout(dead)) {
|
||||||
if (dead_nodes.find(port.node) == dead_nodes.end()) {
|
if (dead_nodes.find(port.node) == dead_nodes.end()) {
|
||||||
zombie_inputs.PushBack(port);
|
zombie_inputs.PushBack(port);
|
||||||
}
|
}
|
||||||
@ -745,7 +747,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
dead_merge_inputs;
|
dead_merge_inputs;
|
||||||
bool found_node_to_preserve = false;
|
bool found_node_to_preserve = false;
|
||||||
while (!found_node_to_preserve && !zombie_inputs.Empty()) {
|
while (!found_node_to_preserve && !zombie_inputs.Empty()) {
|
||||||
GraphView::InputPort dead = zombie_inputs.PopBack();
|
MutableGraphView::InputPort dead = zombie_inputs.PopBack();
|
||||||
if (nodes_to_preserve.find(dead.node->name()) !=
|
if (nodes_to_preserve.find(dead.node->name()) !=
|
||||||
nodes_to_preserve.end()) {
|
nodes_to_preserve.end()) {
|
||||||
found_node_to_preserve = true;
|
found_node_to_preserve = true;
|
||||||
@ -764,9 +766,9 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
found_node_to_preserve = true;
|
found_node_to_preserve = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
GraphView::OutputPort value_index(dead.node, 1);
|
MutableGraphView::OutputPort value_index(dead.node, 1);
|
||||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
|
const absl::flat_hash_set<MutableGraphView::InputPort>& index_fanout =
|
||||||
index_fanout = view.GetFanout(value_index);
|
view.GetFanout(value_index);
|
||||||
if (!index_fanout.empty()) {
|
if (!index_fanout.empty()) {
|
||||||
// The 2nd output (that indicates which input is propagated) is
|
// The 2nd output (that indicates which input is propagated) is
|
||||||
// connected. This never happens in practice, so we'll just skip this
|
// connected. This never happens in practice, so we'll just skip this
|
||||||
@ -789,7 +791,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
}
|
}
|
||||||
if (fully_dead) {
|
if (fully_dead) {
|
||||||
local_dead_nodes.insert(dead.node);
|
local_dead_nodes.insert(dead.node);
|
||||||
for (const GraphView::InputPort& port :
|
for (const MutableGraphView::InputPort& port :
|
||||||
view.GetFanouts(*dead.node, true)) {
|
view.GetFanouts(*dead.node, true)) {
|
||||||
zombie_inputs.PushBack(port);
|
zombie_inputs.PushBack(port);
|
||||||
}
|
}
|
||||||
@ -800,7 +802,7 @@ Status LoopOptimizer::RemoveDeadBranches(
|
|||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
if (local_dead_nodes.insert(dead.node).second) {
|
if (local_dead_nodes.insert(dead.node).second) {
|
||||||
for (const GraphView::InputPort& dead_fanout :
|
for (const MutableGraphView::InputPort& dead_fanout :
|
||||||
view.GetFanouts(*dead.node, true)) {
|
view.GetFanouts(*dead.node, true)) {
|
||||||
zombie_inputs.PushBack(dead_fanout);
|
zombie_inputs.PushBack(dead_fanout);
|
||||||
}
|
}
|
||||||
|
@ -30,8 +30,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
#include "tensorflow/core/grappler/costs/graph_memory.h"
|
||||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||||
#include "tensorflow/core/grappler/costs/utils.h"
|
#include "tensorflow/core/grappler/costs/utils.h"
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
|
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
|
||||||
@ -497,7 +497,7 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
|
|||||||
|
|
||||||
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||||
// Look for AddN nodes (and equivalent) and record input names.
|
// Look for AddN nodes (and equivalent) and record input names.
|
||||||
GraphView view(&item->graph);
|
MutableGraphView view(&item->graph);
|
||||||
|
|
||||||
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
|
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
|
||||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||||
@ -592,7 +592,7 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
|||||||
for (int i = 0; i < node->input_size(); ++i) {
|
for (int i = 0; i < node->input_size(); ++i) {
|
||||||
const string& input = node->input(i);
|
const string& input = node->input(i);
|
||||||
const string node_name = NodeName(input);
|
const string node_name = NodeName(input);
|
||||||
NodeDef* node = view.GetNode(node_name);
|
const NodeDef* node = view.GetNode(node_name);
|
||||||
input_topo_index.push_back(topo_order.at(node));
|
input_topo_index.push_back(topo_order.at(node));
|
||||||
}
|
}
|
||||||
int min_input_topo_index = INT_MAX;
|
int min_input_topo_index = INT_MAX;
|
||||||
@ -834,7 +834,8 @@ static const NodeDef* FindSwapInTrigger(
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
static bool IsSwappable(const MutableGraphView& graph,
|
||||||
|
MutableGraphView::OutputPort output) {
|
||||||
const NodeDef& node = *output.node;
|
const NodeDef& node = *output.node;
|
||||||
// There is no point in swapping out persistent tensors, since the tensor will
|
// There is no point in swapping out persistent tensors, since the tensor will
|
||||||
// continue to use memory.
|
// continue to use memory.
|
||||||
@ -860,10 +861,10 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
|||||||
// If placed on the same device, these nodes are just forwarding references
|
// If placed on the same device, these nodes are just forwarding references
|
||||||
// to their input. Therefore they are swappable iff their fanin is swappable
|
// to their input. Therefore they are swappable iff their fanin is swappable
|
||||||
// or it resides on a different device.
|
// or it resides on a different device.
|
||||||
GraphView::InputPort input;
|
MutableGraphView::InputPort input;
|
||||||
input.node = output.node;
|
input.node = output.node;
|
||||||
input.port_id = 0;
|
input.port_id = 0;
|
||||||
GraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
MutableGraphView::OutputPort fanin = graph.GetRegularFanin(input);
|
||||||
if (fanin.node->device() == node.device()) {
|
if (fanin.node->device() == node.device()) {
|
||||||
return IsSwappable(graph, fanin);
|
return IsSwappable(graph, fanin);
|
||||||
}
|
}
|
||||||
@ -872,19 +873,19 @@ static bool IsSwappable(const GraphView& graph, GraphView::OutputPort output) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static NodeDef* FindSwapOutTrigger(
|
static NodeDef* FindSwapOutTrigger(
|
||||||
const NodeDef* node, int input_id, const GraphView& view,
|
const NodeDef* node, int input_id, const MutableGraphView& view,
|
||||||
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
|
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
|
||||||
execution_times) {
|
execution_times) {
|
||||||
// Find the output port that generated the tensor to swap.
|
// Find the output port that generated the tensor to swap.
|
||||||
GraphView::InputPort swap;
|
MutableGraphView::InputPort swap;
|
||||||
swap.node = const_cast<NodeDef*>(node);
|
swap.node = const_cast<NodeDef*>(node);
|
||||||
swap.port_id = input_id;
|
swap.port_id = input_id;
|
||||||
GraphView::OutputPort generator = view.GetRegularFanin(swap);
|
MutableGraphView::OutputPort generator = view.GetRegularFanin(swap);
|
||||||
if (!generator.node) {
|
if (!generator.node) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_set<GraphView::InputPort, GraphView::HashPort>& fanout =
|
const absl::flat_hash_set<MutableGraphView::InputPort>& fanout =
|
||||||
view.GetFanout(generator);
|
view.GetFanout(generator);
|
||||||
NodeDef* trigger = nullptr;
|
NodeDef* trigger = nullptr;
|
||||||
Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
|
Costs::NanoSeconds earliest_fanout(Costs::NanoSeconds::infinity());
|
||||||
@ -903,7 +904,7 @@ static NodeDef* FindSwapOutTrigger(
|
|||||||
return trigger;
|
return trigger;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool IsSwappable(GraphView::InputPort input) {
|
static bool IsSwappable(MutableGraphView::InputPort input) {
|
||||||
const NodeDef& node = *input.node;
|
const NodeDef& node = *input.node;
|
||||||
|
|
||||||
const OpDef* op_def;
|
const OpDef* op_def;
|
||||||
@ -920,9 +921,9 @@ static bool IsSwappable(GraphView::InputPort input) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct MemInfo {
|
struct MemInfo {
|
||||||
GraphView::OutputPort port;
|
MutableGraphView::OutputPort port;
|
||||||
int64 memory_used;
|
int64 memory_used;
|
||||||
std::vector<GraphView::InputPort> uses_left;
|
std::vector<MutableGraphView::InputPort> uses_left;
|
||||||
double fitness;
|
double fitness;
|
||||||
|
|
||||||
bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
|
bool operator<(const MemInfo& other) const { return fitness < other.fitness; }
|
||||||
@ -993,7 +994,7 @@ static bool IdentifySwappingCandidates(
|
|||||||
|
|
||||||
std::vector<MemInfo> mem_state;
|
std::vector<MemInfo> mem_state;
|
||||||
|
|
||||||
GraphView graph(&item->graph);
|
MutableGraphView graph(&item->graph);
|
||||||
for (const auto& live_tensor : mem_usage.live_tensors) {
|
for (const auto& live_tensor : mem_usage.live_tensors) {
|
||||||
if (live_tensor.memory_used <= 1024) {
|
if (live_tensor.memory_used <= 1024) {
|
||||||
// Don't bother with small tensors.
|
// Don't bother with small tensors.
|
||||||
@ -1009,7 +1010,7 @@ static bool IdentifySwappingCandidates(
|
|||||||
if (skip_list->find(live_tensor.node) != skip_list->end()) {
|
if (skip_list->find(live_tensor.node) != skip_list->end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
GraphView::OutputPort port =
|
MutableGraphView::OutputPort port =
|
||||||
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
|
graph.GetOutputPort(live_tensor.node, live_tensor.output_id);
|
||||||
if (!IsSwappable(graph, port)) {
|
if (!IsSwappable(graph, port)) {
|
||||||
continue;
|
continue;
|
||||||
@ -1020,7 +1021,7 @@ static bool IdentifySwappingCandidates(
|
|||||||
Costs::Duration allocation_time = live_tensor.allocation_time;
|
Costs::Duration allocation_time = live_tensor.allocation_time;
|
||||||
Costs::Duration earliest_use(Costs::Duration::infinity());
|
Costs::Duration earliest_use(Costs::Duration::infinity());
|
||||||
bool valid = true;
|
bool valid = true;
|
||||||
for (GraphView::InputPort input : graph.GetFanout(port)) {
|
for (MutableGraphView::InputPort input : graph.GetFanout(port)) {
|
||||||
// Get execution time.
|
// Get execution time.
|
||||||
auto it = op_completion_times.find(input.node->name());
|
auto it = op_completion_times.find(input.node->name());
|
||||||
if (it == op_completion_times.end()) {
|
if (it == op_completion_times.end()) {
|
||||||
@ -1062,7 +1063,7 @@ static bool IdentifySwappingCandidates(
|
|||||||
// the values do not fit into any integral type.
|
// the values do not fit into any integral type.
|
||||||
mem_info.fitness =
|
mem_info.fitness =
|
||||||
MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
|
MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
|
||||||
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
|
MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
|
||||||
MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
|
MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
|
||||||
mem_info.fitness = -mem_info.fitness;
|
mem_info.fitness = -mem_info.fitness;
|
||||||
mem_state.push_back(mem_info);
|
mem_state.push_back(mem_info);
|
||||||
@ -1073,7 +1074,8 @@ static bool IdentifySwappingCandidates(
|
|||||||
std::sort(mem_state.begin(), mem_state.end());
|
std::sort(mem_state.begin(), mem_state.end());
|
||||||
|
|
||||||
for (const MemInfo& mem_info : mem_state) {
|
for (const MemInfo& mem_info : mem_state) {
|
||||||
for (const GraphView::InputPort fanout_to_swap : mem_info.uses_left) {
|
for (const MutableGraphView::InputPort fanout_to_swap :
|
||||||
|
mem_info.uses_left) {
|
||||||
VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
|
VLOG(1) << "Will swap fanout " << fanout_to_swap.node->name() << ":"
|
||||||
<< fanout_to_swap.port_id << " of tensor "
|
<< fanout_to_swap.port_id << " of tensor "
|
||||||
<< mem_info.port.node->name() << ":" << mem_info.port.port_id
|
<< mem_info.port.node->name() << ":" << mem_info.port.port_id
|
||||||
@ -1150,7 +1152,7 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level,
|
|||||||
for (const auto& node : item->graph.node()) {
|
for (const auto& node : item->graph.node()) {
|
||||||
name_map[node.name()] = &node;
|
name_map[node.name()] = &node;
|
||||||
}
|
}
|
||||||
GraphView view(&item->graph);
|
MutableGraphView view(&item->graph);
|
||||||
|
|
||||||
bool updated_graph = false;
|
bool updated_graph = false;
|
||||||
|
|
||||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
|
||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
|
||||||
@ -34,7 +34,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
|
|
||||||
GraphProperties properties(item);
|
GraphProperties properties(item);
|
||||||
bool inferred_properties = false;
|
bool inferred_properties = false;
|
||||||
GraphView graph(optimized_graph);
|
MutableGraphView graph(optimized_graph);
|
||||||
|
|
||||||
// The product of all the dimensions in a tensor shape can be expressed more
|
// The product of all the dimensions in a tensor shape can be expressed more
|
||||||
// simply as the size of the tensor.
|
// simply as the size of the tensor.
|
||||||
@ -42,8 +42,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
if (!IsShape(node)) {
|
if (!IsShape(node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (GraphView::InputPort fanout :
|
for (MutableGraphView::InputPort fanout :
|
||||||
graph.GetFanout(GraphView::OutputPort(&node, 0))) {
|
graph.GetFanout(MutableGraphView::OutputPort(&node, 0))) {
|
||||||
if (fanout.node->op() != "Prod") {
|
if (fanout.node->op() != "Prod") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -53,8 +53,8 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// rewrite the whole expression directly as a Size operation.
|
// rewrite the whole expression directly as a Size operation.
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const GraphView::OutputPort reduce_indices =
|
const MutableGraphView::OutputPort reduce_indices =
|
||||||
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
|
graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
|
||||||
if (!inferred_properties) {
|
if (!inferred_properties) {
|
||||||
// Infer properties lazily in case they are not needed.
|
// Infer properties lazily in case they are not needed.
|
||||||
TF_RETURN_IF_ERROR(properties.InferStatically(false));
|
TF_RETURN_IF_ERROR(properties.InferStatically(false));
|
||||||
@ -90,10 +90,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// is possible whenever the symbolic dimensions in the numerator and
|
// is possible whenever the symbolic dimensions in the numerator and
|
||||||
// denominator cancel each other.
|
// denominator cancel each other.
|
||||||
if (node.op() == "Div") {
|
if (node.op() == "Div") {
|
||||||
const GraphView::OutputPort input1 =
|
const MutableGraphView::OutputPort input1 =
|
||||||
graph.GetRegularFanin(GraphView::InputPort(&node, 0));
|
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 0));
|
||||||
const GraphView::OutputPort input2 =
|
const MutableGraphView::OutputPort input2 =
|
||||||
graph.GetRegularFanin(GraphView::InputPort(&node, 1));
|
graph.GetRegularFanin(MutableGraphView::InputPort(&node, 1));
|
||||||
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
|
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -101,6 +101,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/grappler:graph_view",
|
"//tensorflow/core/grappler:graph_view",
|
||||||
|
"//tensorflow/core/grappler:mutable_graph_view",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -21,8 +21,11 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
void ReverseDfs(
|
namespace {
|
||||||
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
|
|
||||||
|
template <typename GraphViewType>
|
||||||
|
void ReverseDfsInternal(
|
||||||
|
const GraphViewType& graph_view, const std::vector<const NodeDef*>& from,
|
||||||
const std::function<void(const NodeDef*)>& pre_order,
|
const std::function<void(const NodeDef*)>& pre_order,
|
||||||
const std::function<void(const NodeDef*)>& post_order,
|
const std::function<void(const NodeDef*)>& post_order,
|
||||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||||
@ -79,5 +82,25 @@ void ReverseDfs(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void ReverseDfs(
|
||||||
|
const GraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||||
|
const std::function<void(const NodeDef*)>& pre_order,
|
||||||
|
const std::function<void(const NodeDef*)>& post_order,
|
||||||
|
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||||
|
ReverseDfsInternal<GraphView>(graph_view, from, pre_order, post_order,
|
||||||
|
on_back_edge);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ReverseDfs(
|
||||||
|
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||||
|
const std::function<void(const NodeDef*)>& pre_order,
|
||||||
|
const std::function<void(const NodeDef*)>& post_order,
|
||||||
|
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge) {
|
||||||
|
ReverseDfsInternal<MutableGraphView>(graph_view, from, pre_order, post_order,
|
||||||
|
on_back_edge);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include "tensorflow/core/grappler/graph_view.h"
|
#include "tensorflow/core/grappler/graph_view.h"
|
||||||
|
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
@ -34,6 +35,12 @@ void ReverseDfs(
|
|||||||
const std::function<void(const NodeDef*)>& post_order,
|
const std::function<void(const NodeDef*)>& post_order,
|
||||||
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
|
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
|
||||||
|
|
||||||
|
void ReverseDfs(
|
||||||
|
const MutableGraphView& graph_view, const std::vector<const NodeDef*>& from,
|
||||||
|
const std::function<void(const NodeDef*)>& pre_order,
|
||||||
|
const std::function<void(const NodeDef*)>& post_order,
|
||||||
|
const std::function<void(const NodeDef*, const NodeDef*)>& on_back_edge);
|
||||||
|
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -14,9 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/utils/traversal.h"
|
#include "tensorflow/core/grappler/utils/traversal.h"
|
||||||
//#include "tensorflow/core/framework/node_def.pb.h"
|
|
||||||
//#include "tensorflow/core/lib/core/status_test_util.h"
|
|
||||||
//#include "tensorflow/core/platform/protobuf.h"
|
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -65,8 +63,16 @@ TEST_F(TraversalTest, ReverseDfsNoLoop) {
|
|||||||
found_back_edge = true;
|
found_back_edge = true;
|
||||||
});
|
});
|
||||||
|
|
||||||
EXPECT_EQ(std::vector<string>({"1", "4", "3", "2", "5", "0"}), pre_order);
|
// Pre/Post order traversals are non deterministic because a node fanin is an
|
||||||
EXPECT_EQ(std::vector<string>({"4", "5", "2", "3", "1", "0"}), post_order);
|
// absl::flat_hash_set with non deterministic traversal order.
|
||||||
|
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
|
||||||
|
|
||||||
|
std::set<ValidTraversal> valid_traversals = {
|
||||||
|
// pre_order post_order
|
||||||
|
{{"1", "4", "3", "2", "5", "0"}, {"4", "5", "2", "3", "1", "0"}},
|
||||||
|
{{"1", "3", "2", "5", "4", "0"}, {"5", "2", "3", "4", "1", "0"}}};
|
||||||
|
|
||||||
|
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
|
||||||
EXPECT_FALSE(found_back_edge);
|
EXPECT_FALSE(found_back_edge);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,8 +98,17 @@ TEST_F(TraversalTest, ReverseDfsWithLoop) {
|
|||||||
back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
|
back_edges.push_back(strings::StrCat(src->name(), "->", dst->name()));
|
||||||
});
|
});
|
||||||
|
|
||||||
EXPECT_EQ(std::vector<string>({"6", "3", "2", "1", "5", "4"}), pre_order);
|
// Pre/Post order traversals are non deterministic because a node fanin is an
|
||||||
EXPECT_EQ(std::vector<string>({"1", "4", "5", "2", "3", "6"}), post_order);
|
// absl::flat_hash_set with non deterministic traversal order.
|
||||||
|
using ValidTraversal = std::pair<std::vector<string>, std::vector<string>>;
|
||||||
|
|
||||||
|
std::set<ValidTraversal> valid_traversals = {
|
||||||
|
// pre_order post_order
|
||||||
|
{{"6", "3", "2", "4", "5", "1"}, {"5", "4", "1", "2", "3", "6"}},
|
||||||
|
{{"6", "3", "2", "1", "5", "4"}, {"1", "4", "5", "2", "3", "6"}},
|
||||||
|
{{"6", "3", "2", "5", "4", "1"}, {"4", "5", "1", "2", "3", "6"}}};
|
||||||
|
|
||||||
|
EXPECT_EQ(valid_traversals.count({pre_order, post_order}), 1);
|
||||||
EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
|
EXPECT_EQ(std::vector<string>({"4->3"}), back_edges);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user