[TensorFlow PE] Fix static analysis warnings in grappler: std::set<NodeDef*> does not make sense, since ordering of pointers is not deterministic. Switch to using absl::flat_hash_set, which is also faster. Add a method GetOutputsOrderedByNodeName for uses where deterministic ordering is required.

This, combined with inlining the methods of NodeMap, also gives a ~4.8% speedup of grappler on a large inference graph.

PiperOrigin-RevId: 304515494
Change-Id: Ia834d2019a142b470333dc8dd4b8151694a6510d
This commit is contained in:
A. Unique TensorFlower 2020-04-02 17:59:31 -07:00 committed by TensorFlower Gardener
parent a51ef16118
commit 4b1c4564e1
8 changed files with 173 additions and 177 deletions

View File

@ -40,6 +40,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings",

View File

@ -1764,7 +1764,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
// Update consumers of node to take new_input as input instead.
void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
const auto consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
if (consumer->input(i) == node_name &&
@ -2910,7 +2910,7 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
const auto consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
if (consumer->input(i) == node_name &&
@ -3561,12 +3561,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
// consumers of `node` are already redirected to `simplified_tensor`.
// Re-push the consumers into `nodes_to_simplify` for further
// optimizations.
const std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
std::vector<NodeDef*> consumers(outputs.begin(), outputs.end());
std::sort(consumers.begin(), consumers.end(),
[](const NodeDef* n1, const NodeDef* n2) {
return n1->name() < n2->name();
});
const std::vector<NodeDef*> consumers =
node_map_->GetOutputsOrderedByNodeName(node->name());
for (NodeDef* consumer : consumers) {
// Update `consumer`'s use of `node` to `input`'s operand.
for (int i = 0; i < consumer->input_size(); ++i) {

View File

@ -217,8 +217,8 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
if (rep == node) {
continue;
}
const std::set<NodeDef*>& tmp = node_map.GetOutputs(node->name());
std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
// Make a copy since we mutate the set below.
const auto fanouts = node_map.GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
// Update consumers of node.
bool updated_fanout = false;

View File

@ -253,7 +253,7 @@ bool ConstantFolding::ForwardInputs(NodeDef* node,
}
}
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
const auto& tmp = node_map_->GetOutputs(node->name());
const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
bool updated_graph = false;
for (int input_idx : inputs_to_forward) {
@ -691,7 +691,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
}
// We make a copy here since we might mutate the set.
const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
const auto outputs = node_map_->GetOutputs(node.name());
for (NodeDef* output : outputs) {
for (int k = 0; k < output->input_size(); ++k) {
int port;
@ -1594,13 +1594,8 @@ Status ConstantFolding::FoldGraph(
}
// We need to record a copy of output nodes before FoldNode() modifies it.
// We also need to ensure that the fanout is sorted deterministically.
const std::set<NodeDef*>& outputs = node_map_->GetOutputs(node->name());
std::vector<NodeDef*> fanout(outputs.begin(), outputs.end());
std::sort(fanout.begin(), fanout.end(),
[](const NodeDef* n1, const NodeDef* n2) {
return n1->name() < n2->name();
});
std::vector<NodeDef*> fanout =
node_map_->GetOutputsOrderedByNodeName(node->name());
bool result_too_large = false;
Status s = FoldNode(node, output, &result_too_large);
processed_nodes.insert(node->name());
@ -2449,12 +2444,8 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
SetTensorValue(DT_BOOL, false, &false_t).ok()) {
// Copy the set of consumers of the switch as they will be manipulated
// below.
const auto& consumer_set = node_map_->GetOutputs(node->name());
std::vector<NodeDef*> consumers(consumer_set.begin(), consumer_set.end());
std::sort(consumers.begin(), consumers.end(),
[](const NodeDef* n1, const NodeDef* n2) {
return n1->name() < n2->name();
});
std::vector<NodeDef*> consumers =
node_map_->GetOutputsOrderedByNodeName(node->name());
// Create constant false & true nodes.
NodeDef tmp_false_node;
tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));

View File

@ -233,7 +233,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
// Constant nodes with no input control dependency are always executed early,
// so we can prune all their output control dependencies.
if (IsConstant(*node) && node->input_size() == 0) {
const std::set<NodeDef*> output_nodes = node_map_->GetOutputs(node_name);
const auto output_nodes = node_map_->GetOutputs(node_name);
for (NodeDef* fanout : output_nodes) {
bool optimize_fanout = false;
bool data_connection = false;

View File

@ -708,7 +708,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
NodeDef* old_op = ops[op_idx];
// Copy the output node set since we'll be modifying the version
// maintained by NodeMap in the loop.
std::set<NodeDef*> output_nodes = node_map->GetOutputs(old_op->name());
auto output_nodes = node_map->GetOutputs(old_op->name());
VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
<< " outputs. Moving them to the ScopedAllocatorSplit node.";
if (VLOG_IS_ON(2)) {

View File

@ -93,76 +93,6 @@ NodeMap::NodeMap(GraphDef* graph) {
}
}
// Drops `name` (normalized via NodeName) from both the node table and the
// fanout table. Fanout sets of *other* nodes that still reference this node
// are not touched here.
void NodeMap::RemoveNode(const string& name) {
  nodes_.erase(NodeName(name));
  outputs_.erase(NodeName(name));
}
// Looks up a node by (normalized) name. Returns nullptr, with a VLOG(1)
// trace, when the node is not in the map.
NodeDef* NodeMap::GetNode(const string& name) const {
  const string node_name = NodeName(name);
  auto it = nodes_.find(node_name);
  if (it == nodes_.end()) {
    VLOG(1) << "Node could not be found: " << name;
    return nullptr;
  }
  return it->second;
}
// True iff a node with the normalized form of `name` is present in the map.
bool NodeMap::NodeExists(const string& name) const {
  const string node_name = NodeName(name);
  return nodes_.find(node_name) != nodes_.end();
}
// Returns the recorded fanout set of `node_name`, or a shared empty set when
// no outputs have been recorded. The returned reference is only valid until
// outputs_ is next mutated.
const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
  auto it = outputs_.find(node_name);
  if (it == outputs_.end()) {
    return empty_set_;
  }
  return it->second;
}
// Registers `node` under `node_name`. CHECK-fails on a null node or when the
// name is already present (names must be unique). Note: outputs are NOT
// recorded here; callers must add them via AddOutput.
void NodeMap::AddNode(const string& node_name, NodeDef* node) {
  auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
  CHECK(ret.second) << "Pair (" << node_name << "," << node
                    << ") is not inserted because the same key already exists.";
}
// Records that `output_name` consumes `node_name`.
// NOTE(review): nodes_[NodeName(output_name)] uses operator[], which
// default-inserts a nullptr entry when the output node is missing before the
// CHECK fires — presumably acceptable since the CHECK aborts; confirm.
void NodeMap::AddOutput(const string& node_name, const string& output_name) {
  auto output_node = nodes_[NodeName(output_name)];
  CHECK(output_node) << "Output node " << output_name
                     << " is missing in NodeMap.";
  outputs_[node_name].insert(output_node);
}
// Removes `output_name` from the fanout set of `node_name`. Both operator[]
// lookups default-insert when the key is absent (empty set / nullptr), in
// which case the erase is a no-op.
void NodeMap::RemoveOutput(const string& node_name, const string& output_name) {
  outputs_[node_name].erase(nodes_[NodeName(output_name)]);
}
// Rewires the fanout bookkeeping when `node_name` switches one of its inputs
// from `old_input_name` to `new_input_name`: the node is removed from the old
// producer's fanout and added to the new producer's fanout.
void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
                          const string& new_input_name) {
  RemoveOutput(NodeName(old_input_name), node_name);
  AddOutput(NodeName(new_input_name), node_name);
}
// Removes `node_name` from the fanout set of each of its producers.
// NOTE(review): nodes_[node_name] default-inserts nullptr when the node is
// absent, and the subsequent node->input() would then dereference null —
// callers presumably only pass known node names; confirm.
void NodeMap::RemoveInputs(const string& node_name) {
  auto node = nodes_[node_name];
  for (const auto& input : node->input()) {
    RemoveOutput(NodeName(input), node->name());
  }
}
// Drops the entire recorded fanout set of `node_name`.
void NodeMap::RemoveOutputs(const string& node_name) {
  outputs_.erase(node_name);
}
// Replaces one consumer in the fanout set of `node_name`: `old_output_name`
// is swapped for `new_output_name`. The nodes_ operator[] lookups
// default-insert nullptr for unknown names; the erase/insert then operate on
// that null pointer.
void NodeMap::UpdateOutput(const string& node_name,
                           const string& old_output_name,
                           const string& new_output_name) {
  std::set<NodeDef*>& outputs = outputs_[node_name];
  outputs.erase(nodes_[NodeName(old_output_name)]);
  outputs.insert(nodes_[NodeName(new_output_name)]);
}
string TensorIdToString(const TensorId& tensor_id) {
return tensor_id.index() == 0 ? string(tensor_id.node())
: tensor_id.ToString();

View File

@ -18,11 +18,10 @@ limitations under the License.
#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
@ -42,84 +41,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// A utility class to lookup a node and its outputs by node name.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);
  // Returns the node registered under `name`, or nullptr if absent.
  NodeDef* GetNode(const string& name) const;
  // True iff a node with this (normalized) name is registered.
  bool NodeExists(const string& name) const;
  // Returns the set of consumers of `node_name`; a shared empty set if none.
  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& name, NodeDef* node);
  // Erases the node and its recorded fanout set.
  void RemoveNode(const string& name);
  // Moves `node_name` from the fanout of `old_input_name`'s producer to that
  // of `new_input_name`'s producer.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name);
  // Records that `output_name` consumes `node_name`.
  void AddOutput(const string& node_name, const string& output_name);
  // Removes `node_name` from the fanout sets of all of its producers.
  void RemoveInputs(const string& node_name);
  // Removes a single consumer from `node_name`'s fanout set.
  void RemoveOutput(const string& node_name, const string& output_name);
  // Drops the entire fanout set of `node_name`.
  void RemoveOutputs(const string& node_name);
  // Swaps one consumer for another in `node_name`'s fanout set.
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name);

 private:
  // Shared return value for GetOutputs() on nodes with no recorded fanout.
  const std::set<NodeDef*> empty_set_;
  absl::node_hash_map<string, NodeDef*> nodes_;
  absl::node_hash_map<string, std::set<NodeDef*>> outputs_;
};
// A vector with a set. The set stores the same elements as the vector, and
// quickly answers whether a value is in the vector. Duplicated elements are not
// allowed for now.
template <class T, class Hash = std::hash<T>>
class SetVector {
 public:
  // Returns false if value already existed in the set, true otherwise.
  bool PushBack(const T& value) {
    if (!set_.insert(value).second) {
      return false;
    }
    vector_.push_back(value);
    return true;
  }

  // Removes and returns the most recently pushed element.
  // Precondition: not Empty() — vector_.back() on an empty vector is UB.
  T PopBack() {
    T back = vector_.back();
    set_.erase(back);
    vector_.pop_back();
    return back;
  }

  // O(1) membership query via the mirror set.
  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }

  // True iff no elements are stored.
  bool Empty() const { return vector_.empty(); }

  // Pre-allocates vector storage for `size` elements.
  void Reserve(int64 size) { vector_.reserve(size); }

 private:
  gtl::FlatSet<T, Hash> set_;
  std::vector<T> vector_;
};
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
string TensorIdToString(const TensorId& tensor_id);
// Returns formatted string from SafeTensorId specific to grappler.
// Specifically, for the 0 port (first output), only the node name is returned.
string SafeTensorIdToString(const SafeTensorId& tensor_id);
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
// True iff tensor index refers to a control input.
bool IsControlInput(const TensorId& tensor_id);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
// Utilities for manipulating node name and input strings.
// Returns the trailing position number (or zero if no number is present) if
// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
@ -176,6 +98,162 @@ inline int NodePosition(const string& name) {
return position;
}
// A utility class to lookup a node and its outputs by node name.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);

  // Get unordered list of fanouts from node. Notice, that the order is
  // non-deterministic.
  const absl::flat_hash_set<NodeDef*>& GetOutputs(
      const string& node_name) const {
    auto it = outputs_.find(node_name);
    if (it == outputs_.end()) {
      return empty_set_;
    }
    return it->second;
  }

  // Get fanouts ordered by name. Use this variant wherever deterministic
  // iteration order is required.
  std::vector<NodeDef*> GetOutputsOrderedByNodeName(
      const string& node_name) const {
    std::vector<NodeDef*> result;
    auto it = outputs_.find(node_name);
    if (it != outputs_.end()) {
      const absl::flat_hash_set<NodeDef*>& outputs = it->second;
      // assign() sizes the vector in a single allocation, so a separate
      // reserve() beforehand is redundant.
      result.assign(outputs.begin(), outputs.end());
      std::sort(result.begin(), result.end(),
                [](const NodeDef* n1, const NodeDef* n2) {
                  return n1->name() < n2->name();
                });
    }
    return result;
  }

  // This method doesn't record the outputs of the added node; the outputs need
  // to be explicitly added by the AddOutput method.
  void AddNode(const string& node_name, NodeDef* node) {
    DCHECK(node != nullptr);
    auto ret = nodes_.emplace(node_name, node);
    DCHECK(ret.second)
        << "Pair (" << node_name << "," << node
        << ") is not inserted because the same key already exists.";
  }

  // Erases the node and its recorded fanout set. Fanout sets of other nodes
  // that still reference this node are not cleaned up here.
  void RemoveNode(const string& name) {
    nodes_.erase(NodeName(name));
    outputs_.erase(NodeName(name));
  }

  // Returns the node registered under the normalized form of `name`, or
  // nullptr (with a VLOG(1) trace) if absent.
  NodeDef* GetNode(const string& name) const {
    const string node_name = NodeName(name);
    auto it = nodes_.find(node_name);
    if (it == nodes_.end()) {
      VLOG(1) << "Node could not be found: " << name;
      return nullptr;
    }
    return it->second;
  }

  // True iff a node with this (normalized) name is registered.
  bool NodeExists(const string& name) const {
    const string node_name = NodeName(name);
    return nodes_.find(node_name) != nodes_.end();
  }

  // Records that `output_name` consumes `node_name`.
  // NOTE(review): operator[] on nodes_ default-inserts nullptr for an unknown
  // output node; with DCHECK compiled out that null is then inserted into the
  // fanout set — presumably callers only pass registered names; confirm.
  void AddOutput(const string& node_name, const string& output_name) {
    auto output_node = nodes_[NodeName(output_name)];
    DCHECK(output_node) << "Output node " << output_name
                        << " is missing in NodeMap.";
    outputs_[node_name].insert(output_node);
  }

  // Removes `output_name` from the fanout set of `node_name`. Unknown keys
  // are default-inserted by operator[]; the erase is then a no-op.
  void RemoveOutput(const string& node_name, const string& output_name) {
    outputs_[node_name].erase(nodes_[NodeName(output_name)]);
  }

  // Rewires fanout bookkeeping when `node_name` switches one of its inputs
  // from `old_input_name` to `new_input_name`.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name) {
    RemoveOutput(NodeName(old_input_name), node_name);
    AddOutput(NodeName(new_input_name), node_name);
  }

  // Removes `node_name` from the fanout set of each of its producers.
  // NOTE(review): nodes_[node_name] default-inserts nullptr for an unknown
  // name and node->input() would dereference it; confirm callers guarantee
  // the node exists.
  void RemoveInputs(const string& node_name) {
    auto node = nodes_[node_name];
    for (const auto& input : node->input()) {
      RemoveOutput(NodeName(input), node->name());
    }
  }

  // Drops the entire recorded fanout set of `node_name`.
  void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }

  // Swaps one consumer for another in `node_name`'s fanout set.
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name) {
    absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
    outputs.erase(nodes_[NodeName(old_output_name)]);
    outputs.insert(nodes_[NodeName(new_output_name)]);
  }

 private:
  // Shared return value for GetOutputs() on nodes with no recorded fanout.
  const absl::flat_hash_set<NodeDef*> empty_set_;
  absl::node_hash_map<string, NodeDef*> nodes_;
  absl::node_hash_map<string, absl::flat_hash_set<NodeDef*>> outputs_;
};
// A vector mirrored by a set: the set always holds exactly the elements of
// the vector, making membership queries O(1). Duplicate elements are
// rejected rather than stored.
template <class T, class Hash = std::hash<T>>
class SetVector {
 public:
  // Appends `value`. Returns true when it was newly added, false when it was
  // already present (in which case nothing changes).
  bool PushBack(const T& value) {
    if (set_.insert(value).second) {
      vector_.push_back(value);
      return true;
    }
    return false;
  }

  // Removes and returns the most recently pushed element.
  T PopBack() {
    T last = vector_.back();
    vector_.pop_back();
    set_.erase(last);
    return last;
  }

  // True iff `value` is currently stored.
  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }

  // True iff no elements are stored.
  bool Empty() const { return vector_.empty(); }

  // Pre-allocates vector storage for `size` elements.
  void Reserve(int64 size) { vector_.reserve(size); }

 private:
  gtl::FlatSet<T, Hash> set_;
  std::vector<T> vector_;
};
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
string TensorIdToString(const TensorId& tensor_id);
// Returns formatted string from SafeTensorId specific to grappler.
// Specifically, for the 0 port (first output), only the node name is returned.
string SafeTensorIdToString(const SafeTensorId& tensor_id);
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
// True iff tensor index refers to a control input.
bool IsControlInput(const TensorId& tensor_id);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter);