[TensorFlow PE] Fix static analysis warnings in grappler: std::set<NodeDef*> does not make sense, since the ordering of pointers is not deterministic. Switch to using absl::flat_hash_set, which is also faster. Add a method GetOutputsOrderedByNodeName for uses where deterministic ordering is required.
This, combined with inlining the methods of NodeMap, also gives a ~4.8% speedup of grappler on a large inference graph.

PiperOrigin-RevId: 304515494
Change-Id: Ia834d2019a142b470333dc8dd4b8151694a6510d
Commit: 4b1c4564e1
Parent: a51ef16118
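For context, the warning arises because std::set<NodeDef*> orders its elements by pointer value, which depends on allocation addresses and can differ from run to run. The standalone sketch below is illustrative only, not TensorFlow code: it uses a hypothetical Node struct in place of NodeDef and shows how sorting fanouts by node name (the approach taken by the new GetOutputsOrderedByNodeName) produces a stable order, while iterating the hash set (or a set of pointers) does not.

// Standalone sketch; `Node` is a stand-in for tensorflow::NodeDef.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"

struct Node {
  std::string name;
};

// Mirrors the idea behind NodeMap::GetOutputsOrderedByNodeName: copy the
// unordered fanout set into a vector and sort by node name so that the
// result is deterministic across runs.
std::vector<Node*> OrderedByName(const absl::flat_hash_set<Node*>& fanouts) {
  std::vector<Node*> result(fanouts.begin(), fanouts.end());
  std::sort(result.begin(), result.end(),
            [](const Node* a, const Node* b) { return a->name < b->name; });
  return result;
}

int main() {
  Node a{"add"}, b{"mul"}, c{"relu"};
  absl::flat_hash_set<Node*> fanouts = {&c, &a, &b};
  // Iterating `fanouts` directly (or a std::set<Node*>, which compares
  // pointer addresses) gives no guaranteed order; the sorted copy does.
  for (const Node* n : OrderedByName(fanouts)) std::cout << n->name << "\n";
  return 0;
}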
@@ -40,6 +40,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
@@ -1764,7 +1764,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
   // Update consumers of node to take new_input as input instead.
   void UpdateConsumers(NodeDef* node, const string& new_input) {
     const string& node_name = node->name();
-    const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+    const auto consumers = ctx().node_map->GetOutputs(node_name);
     for (NodeDef* consumer : consumers) {
       for (int i = 0; i < consumer->input_size(); ++i) {
         if (consumer->input(i) == node_name &&
@@ -2910,7 +2910,7 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {

   void UpdateConsumers(NodeDef* node, const string& new_input) {
     const string& node_name = node->name();
-    const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+    const auto consumers = ctx().node_map->GetOutputs(node_name);
     for (NodeDef* consumer : consumers) {
       for (int i = 0; i < consumer->input_size(); ++i) {
         if (consumer->input(i) == node_name &&
@@ -3561,12 +3561,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
     // consumers of `node` are already redirected to `simplified_tensor`.
     // Re-push the consumers into `nodes_to_simplify` for further
    // optimizations.
-    const std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
-    std::vector<NodeDef*> consumers(outputs.begin(), outputs.end());
-    std::sort(consumers.begin(), consumers.end(),
-              [](const NodeDef* n1, const NodeDef* n2) {
-                return n1->name() < n2->name();
-              });
+    const std::vector<NodeDef*> consumers =
+        node_map_->GetOutputsOrderedByNodeName(node->name());
     for (NodeDef* consumer : consumers) {
       // Update `consumer`'s use of `node` to `input`'s operand.
       for (int i = 0; i < consumer->input_size(); ++i) {
@@ -217,8 +217,8 @@ Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
       if (rep == node) {
        continue;
       }
-      const std::set<NodeDef*>& tmp = node_map.GetOutputs(node->name());
-      std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
+      // Make a copy since we mutate the set below.
+      const auto fanouts = node_map.GetOutputs(node->name());
       for (NodeDef* fanout : fanouts) {
         // Update consumers of node.
         bool updated_fanout = false;
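One subtlety in the hunk above: GetOutputs now returns a const absl::flat_hash_set<NodeDef*>&, so declaring the result as `const auto fanouts = ...` deduces the value type and therefore makes a copy, which is exactly what the "Make a copy since we mutate the set below" comment relies on. A minimal standalone illustration (plain int elements instead of NodeDef*, purely hypothetical code):

#include <cassert>

#include "absl/container/flat_hash_set.h"

const absl::flat_hash_set<int>& GetSet() {
  static const absl::flat_hash_set<int> s = {1, 2, 3};
  return s;
}

int main() {
  // `auto` (not `auto&`) deduces absl::flat_hash_set<int>, so this is a copy;
  // mutating it cannot invalidate iteration over the original set.
  const auto copy = GetSet();
  assert(&copy != &GetSet());
  // `auto&` would bind a reference instead, and no copy would be made.
  const auto& ref = GetSet();
  assert(&ref == &GetSet());
  return 0;
}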
@@ -253,7 +253,7 @@ bool ConstantFolding::ForwardInputs(NodeDef* node,
     }
   }

-  const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
+  const auto& tmp = node_map_->GetOutputs(node->name());
   const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
   bool updated_graph = false;
   for (int input_idx : inputs_to_forward) {
@@ -691,7 +691,7 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
   }

   // We make a copy here since we might mutate the set.
-  const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
+  const auto outputs = node_map_->GetOutputs(node.name());
   for (NodeDef* output : outputs) {
     for (int k = 0; k < output->input_size(); ++k) {
       int port;
@@ -1594,13 +1594,8 @@ Status ConstantFolding::FoldGraph(
    }
    // We need to record a copy of output nodes before FoldNode() modifies it.
    // We also need to ensure that the fanout is sorted deterministically.
-    const std::set<NodeDef*>& outputs = node_map_->GetOutputs(node->name());
-    std::vector<NodeDef*> fanout(outputs.begin(), outputs.end());
-    std::sort(fanout.begin(), fanout.end(),
-              [](const NodeDef* n1, const NodeDef* n2) {
-                return n1->name() < n2->name();
-              });
-
+    std::vector<NodeDef*> fanout =
+        node_map_->GetOutputsOrderedByNodeName(node->name());
    bool result_too_large = false;
    Status s = FoldNode(node, output, &result_too_large);
    processed_nodes.insert(node->name());
@@ -2449,12 +2444,8 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
       SetTensorValue(DT_BOOL, false, &false_t).ok()) {
     // Copy the set of consumers of the switch as they will be manipulated
     // below.
-    const auto& consumer_set = node_map_->GetOutputs(node->name());
-    std::vector<NodeDef*> consumers(consumer_set.begin(), consumer_set.end());
-    std::sort(consumers.begin(), consumers.end(),
-              [](const NodeDef* n1, const NodeDef* n2) {
-                return n1->name() < n2->name();
-              });
+    std::vector<NodeDef*> consumers =
+        node_map_->GetOutputsOrderedByNodeName(node->name());
     // Create constant false & true nodes.
     NodeDef tmp_false_node;
     tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
@@ -233,7 +233,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
   // Constant nodes with no input control dependency are always executed early,
   // so we can prune all their output control dependencies.
   if (IsConstant(*node) && node->input_size() == 0) {
-    const std::set<NodeDef*> output_nodes = node_map_->GetOutputs(node_name);
+    const auto output_nodes = node_map_->GetOutputs(node_name);
     for (NodeDef* fanout : output_nodes) {
       bool optimize_fanout = false;
       bool data_connection = false;
@@ -708,7 +708,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
       NodeDef* old_op = ops[op_idx];
       // Copy the output node set since we'll be modifying the version
       // maintained by NodeMap in the loop.
-      std::set<NodeDef*> output_nodes = node_map->GetOutputs(old_op->name());
+      auto output_nodes = node_map->GetOutputs(old_op->name());
       VLOG(3) << "old_op " << old_op->name() << " had " << output_nodes.size()
               << " outputs. Moving them to the ScopedAllocatorSplit node.";
       if (VLOG_IS_ON(2)) {
@@ -93,76 +93,6 @@ NodeMap::NodeMap(GraphDef* graph) {
   }
 }

-void NodeMap::RemoveNode(const string& name) {
-  nodes_.erase(NodeName(name));
-  outputs_.erase(NodeName(name));
-}
-
-NodeDef* NodeMap::GetNode(const string& name) const {
-  const string node_name = NodeName(name);
-  auto it = nodes_.find(node_name);
-  if (it == nodes_.end()) {
-    VLOG(1) << "Node could not be found: " << name;
-    return nullptr;
-  }
-  return it->second;
-}
-
-bool NodeMap::NodeExists(const string& name) const {
-  const string node_name = NodeName(name);
-  return nodes_.find(node_name) != nodes_.end();
-}
-
-const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const {
-  auto it = outputs_.find(node_name);
-  if (it == outputs_.end()) {
-    return empty_set_;
-  }
-  return it->second;
-}
-
-void NodeMap::AddNode(const string& node_name, NodeDef* node) {
-  auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node));
-  CHECK(ret.second) << "Pair (" << node_name << "," << node
-                    << ") is not inserted because the same key already exists.";
-}
-
-void NodeMap::AddOutput(const string& node_name, const string& output_name) {
-  auto output_node = nodes_[NodeName(output_name)];
-  CHECK(output_node) << "Output node " << output_name
-                     << " is missing in NodeMap.";
-  outputs_[node_name].insert(output_node);
-}
-
-void NodeMap::RemoveOutput(const string& node_name, const string& output_name) {
-  outputs_[node_name].erase(nodes_[NodeName(output_name)]);
-}
-
-void NodeMap::UpdateInput(const string& node_name, const string& old_input_name,
-                          const string& new_input_name) {
-  RemoveOutput(NodeName(old_input_name), node_name);
-  AddOutput(NodeName(new_input_name), node_name);
-}
-
-void NodeMap::RemoveInputs(const string& node_name) {
-  auto node = nodes_[node_name];
-  for (const auto& input : node->input()) {
-    RemoveOutput(NodeName(input), node->name());
-  }
-}
-
-void NodeMap::RemoveOutputs(const string& node_name) {
-  outputs_.erase(node_name);
-}
-
-void NodeMap::UpdateOutput(const string& node_name,
-                           const string& old_output_name,
-                           const string& new_output_name) {
-  std::set<NodeDef*>& outputs = outputs_[node_name];
-  outputs.erase(nodes_[NodeName(old_output_name)]);
-  outputs.insert(nodes_[NodeName(new_output_name)]);
-}
-
 string TensorIdToString(const TensorId& tensor_id) {
   return tensor_id.index() == 0 ? string(tensor_id.node())
                                 : tensor_id.ToString();
@@ -18,11 +18,10 @@ limitations under the License.

 #include <functional>
 #include <iterator>
-#include <set>
-#include <unordered_set>
 #include <utility>
 #include <vector>

+#include "absl/container/flat_hash_set.h"
 #include "absl/container/node_hash_map.h"
 #include "absl/strings/string_view.h"
 #include "absl/types/span.h"
@@ -42,84 +41,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {

-// A utility class to lookup a node and its outputs by node name.
-class NodeMap {
- public:
-  // Note: The NodeMap will store pointers to nodes in graph, which may become
-  // invalid if graph is changed.
-  explicit NodeMap(GraphDef* graph);
-  NodeDef* GetNode(const string& name) const;
-  bool NodeExists(const string& name) const;
-  const std::set<NodeDef*>& GetOutputs(const string& node_name) const;
-  // This method doesn't record the outputs of the added node; the outputs need
-  // to be explicitly added by the AddOutput method.
-  void AddNode(const string& name, NodeDef* node);
-  void RemoveNode(const string& name);
-  void UpdateInput(const string& node_name, const string& old_input_name,
-                   const string& new_input_name);
-  void AddOutput(const string& node_name, const string& output_name);
-  void RemoveInputs(const string& node_name);
-  void RemoveOutput(const string& node_name, const string& output_name);
-  void RemoveOutputs(const string& node_name);
-  void UpdateOutput(const string& node_name, const string& old_output_name,
-                    const string& new_output_name);
-
- private:
-  const std::set<NodeDef*> empty_set_;
-  absl::node_hash_map<string, NodeDef*> nodes_;
-  absl::node_hash_map<string, std::set<NodeDef*>> outputs_;
-};
-
-// A vector with a set. The set stores the same elements as the vector, and
-// quickly answers whether a value is in the vector. Duplicated elements are not
-// allowed for now.
-template <class T, class Hash = std::hash<T>>
-class SetVector {
- public:
-  // Returns false if value already existed in the set, true otherwise.
-  bool PushBack(const T& value) {
-    if (!set_.insert(value).second) {
-      return false;
-    }
-    vector_.push_back(value);
-    return true;
-  }
-
-  T PopBack() {
-    T back = vector_.back();
-    set_.erase(back);
-    vector_.pop_back();
-    return back;
-  }
-
-  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
-
-  bool Empty() const { return vector_.empty(); }
-
-  void Reserve(int64 size) { vector_.reserve(size); }
-
- private:
-  gtl::FlatSet<T, Hash> set_;
-  std::vector<T> vector_;
-};
-
-// Returns formatted string from TensorId specific to grappler. Specifically,
-// for the 0 port (first output), only the node name is returned.
-string TensorIdToString(const TensorId& tensor_id);
-
-// Returns formatted string from SafeTensorId specific to grappler.
-// Specifically, for the 0 port (first output), only the node name is returned.
-string SafeTensorIdToString(const SafeTensorId& tensor_id);
-
-// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
-// the ^ character.
-bool IsControlInput(const string& name);
-
-// True iff tensor index refers to a control input.
-bool IsControlInput(const TensorId& tensor_id);
-
-// True iff 'name1' and 'name2' refer to the same input.
-bool IsSameInput(const string& name1, const string& name2);
 // Utilities for manipulating node name and input strings.

 // Returns the trailing position number (or zero if no number is present) if
 // NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
@@ -176,6 +98,162 @@ inline int NodePosition(const string& name) {
   return position;
 }

+// A utility class to lookup a node and its outputs by node name.
+class NodeMap {
+ public:
+  // Note: The NodeMap will store pointers to nodes in graph, which may become
+  // invalid if graph is changed.
+  explicit NodeMap(GraphDef* graph);
+
+  // Get unordered list of fanouts from node. Notice, that the order is
+  // non-deterministic.
+  const absl::flat_hash_set<NodeDef*>& GetOutputs(
+      const string& node_name) const {
+    auto it = outputs_.find(node_name);
+    if (it == outputs_.end()) {
+      return empty_set_;
+    }
+    return it->second;
+  }
+
+  // Get fanouts ordered by name.
+  std::vector<NodeDef*> GetOutputsOrderedByNodeName(
+      const string& node_name) const {
+    std::vector<NodeDef*> result;
+    auto it = outputs_.find(node_name);
+    if (it != outputs_.end()) {
+      const absl::flat_hash_set<NodeDef*>& outputs = it->second;
+      result.reserve(outputs.size());
+      result.assign(outputs.begin(), outputs.end());
+      std::sort(result.begin(), result.end(),
+                [](const NodeDef* n1, const NodeDef* n2) {
+                  return n1->name() < n2->name();
+                });
+    }
+    return result;
+  }
+
+  // This method doesn't record the outputs of the added node; the outputs need
+  // to be explicitly added by the AddOutput method.
+  void AddNode(const string& node_name, NodeDef* node) {
+    DCHECK(node != nullptr);
+    auto ret = nodes_.emplace(node_name, node);
+    DCHECK(ret.second)
+        << "Pair (" << node_name << "," << node
+        << ") is not inserted because the same key already exists.";
+  }
+
+  void RemoveNode(const string& name) {
+    nodes_.erase(NodeName(name));
+    outputs_.erase(NodeName(name));
+  }
+
+  NodeDef* GetNode(const string& name) const {
+    const string node_name = NodeName(name);
+    auto it = nodes_.find(node_name);
+    if (it == nodes_.end()) {
+      VLOG(1) << "Node could not be found: " << name;
+      return nullptr;
+    }
+    return it->second;
+  }
+
+  bool NodeExists(const string& name) const {
+    const string node_name = NodeName(name);
+    return nodes_.find(node_name) != nodes_.end();
+  }
+
+  void AddOutput(const string& node_name, const string& output_name) {
+    auto output_node = nodes_[NodeName(output_name)];
+    DCHECK(output_node) << "Output node " << output_name
+                        << " is missing in NodeMap.";
+    outputs_[node_name].insert(output_node);
+  }
+
+  void RemoveOutput(const string& node_name, const string& output_name) {
+    outputs_[node_name].erase(nodes_[NodeName(output_name)]);
+  }
+
+  void UpdateInput(const string& node_name, const string& old_input_name,
+                   const string& new_input_name) {
+    RemoveOutput(NodeName(old_input_name), node_name);
+    AddOutput(NodeName(new_input_name), node_name);
+  }
+
+  void RemoveInputs(const string& node_name) {
+    auto node = nodes_[node_name];
+    for (const auto& input : node->input()) {
+      RemoveOutput(NodeName(input), node->name());
+    }
+  }
+
+  void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }
+
+  void UpdateOutput(const string& node_name, const string& old_output_name,
+                    const string& new_output_name) {
+    absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
+    outputs.erase(nodes_[NodeName(old_output_name)]);
+    outputs.insert(nodes_[NodeName(new_output_name)]);
+  }
+
+ private:
+  const absl::flat_hash_set<NodeDef*> empty_set_;
+  absl::node_hash_map<string, NodeDef*> nodes_;
+  absl::node_hash_map<string, absl::flat_hash_set<NodeDef*>> outputs_;
+};
+
+// A vector with a set. The set stores the same elements as the vector, and
+// quickly answers whether a value is in the vector. Duplicated elements are not
+// allowed for now.
+template <class T, class Hash = std::hash<T>>
+class SetVector {
+ public:
+  // Returns false if value already existed in the set, true otherwise.
+  bool PushBack(const T& value) {
+    if (!set_.insert(value).second) {
+      return false;
+    }
+    vector_.push_back(value);
+    return true;
+  }
+
+  T PopBack() {
+    T back = vector_.back();
+    set_.erase(back);
+    vector_.pop_back();
+    return back;
+  }
+
+  bool Exists(const T& value) const { return set_.find(value) != set_.end(); }
+
+  bool Empty() const { return vector_.empty(); }
+
+  void Reserve(int64 size) { vector_.reserve(size); }
+
+ private:
+  gtl::FlatSet<T, Hash> set_;
+  std::vector<T> vector_;
+};
+
+// Returns formatted string from TensorId specific to grappler. Specifically,
+// for the 0 port (first output), only the node name is returned.
+string TensorIdToString(const TensorId& tensor_id);
+
+// Returns formatted string from SafeTensorId specific to grappler.
+// Specifically, for the 0 port (first output), only the node name is returned.
+string SafeTensorIdToString(const SafeTensorId& tensor_id);
+
+// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
+// the ^ character.
+bool IsControlInput(const string& name);
+
+// True iff tensor index refers to a control input.
+bool IsControlInput(const TensorId& tensor_id);
+
+// True iff 'name1' and 'name2' refer to the same input.
+bool IsSameInput(const string& name1, const string& name2);
+
 // Add a prefix to a node name with a custom delimiter.
 string AddPrefixToNodeName(const string& name, const string& prefix,
                            const string& delimiter);