STT-tensorflow/tensorflow/core/grappler/utils.h
A. Unique TensorFlower 4b1c4564e1 [TensorFlow PE] Fix static analysis warnings in grappler: std::set<NodeDef*> does not make sense, since ordering of pointers is not deterministic. Switch to using absl::flat_hash_set, which is also faster. Add a method GetOutputsOrderedByNodeName for uses where deterministic ordering is required.
This, combined with inlining the methods of NodeMap also gives a ~4.8% speedup of grappler on a large inference graph.

PiperOrigin-RevId: 304515494
Change-Id: Ia834d2019a142b470333dc8dd4b8151694a6510d
2020-04-02 18:11:47 -07:00

371 lines
13 KiB
C++

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_H_
#include <algorithm>
#include <functional>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
// Utilities for manipulating node name and input strings.
// Checks whether input string `input_name` refers to node `node_name`.
//
// Returns:
//   * the trailing position number of `input_name` (zero if no number is
//     present) when NodeName(input_name) equals `node_name`;
//   * -1 when `input_name` is a control input ("^node_name") of that node;
//   * -2 when `input_name` is empty or names a different node.
inline int NodePositionIfSameNode(absl::string_view input_name,
                                  absl::string_view node_name) {
  const bool is_control_input = absl::StartsWith(input_name, "^");
  if (is_control_input) {
    input_name.remove_prefix(1);
  }
  // Cheap rejections before parsing: an input shorter than the node name can
  // never refer to it.
  const bool cannot_match = input_name.empty() || node_name.empty() ||
                            input_name.size() < node_name.size();
  if (cannot_match) return -2;

  const TensorId tensor_id = ParseTensorName(input_name);
  if (tensor_id.first != node_name) return -2;
  return is_control_input ? -1 : tensor_id.second;
}
// Returns the node name and position in a single call.
//
// `name` is an input string such as "node", "node:1" or "^node". If
// `position` is non-null it receives -1 when `name` starts with "^" (a control
// input), and otherwise the index reported by ParseTensorName. The returned
// StringPiece is taken from ParseTensorName's result without copying, so it
// presumably views the characters of `name` and must not outlive the buffer
// `name` refers to.
inline StringPiece ParseNodeNameAsStringPiece(absl::string_view name,
int* position) {
const bool is_control = absl::StartsWith(name, "^");
TensorId id = ParseTensorName(name);
if (position) {
*position = is_control ? -1 : id.second;
}
// NOTE(review): this appears to rely on ParseTensorName already dropping the
// leading "^" when it reports a negative (control) index, as for "^node".
// For a string like "^node:1" it reports a non-negative index with the "^"
// still in place, so the prefix is stripped here. Confirm against
// ParseTensorName.
if (is_control && id.second >= 0) {
id.first.remove_prefix(1);
}
return id.first;
}
// Owning variant of ParseNodeNameAsStringPiece: returns a copy of the node
// name parsed from `name` and, if `position` is non-null, stores the parsed
// position there.
inline string ParseNodeName(const string& name, int* position) {
  const StringPiece node_name = ParseNodeNameAsStringPiece(name, position);
  return string(node_name.data(), node_name.size());
}
// Returns the node-name portion of input string `name`, as parsed by
// ParseNodeNameAsStringPiece; the position is discarded. The result is not a
// copy, so it presumably must not outlive `name`.
inline StringPiece NodeNameAsStringPiece(const string& name) {
  int* const ignored_position = nullptr;
  return ParseNodeNameAsStringPiece(name, ignored_position);
}
// Returns an owned copy of the node-name portion of input string `name`
// (see NodeNameAsStringPiece).
inline string NodeName(const string& name) {
  const StringPiece node_name = NodeNameAsStringPiece(name);
  return string(node_name.data(), node_name.size());
}
// Returns the position parsed from input string `name`: -1 for a control
// input ("^node"), otherwise the index reported by ParseNodeNameAsStringPiece.
inline int NodePosition(const string& name) {
  // Always overwritten below because a non-null pointer is passed.
  int parsed_position = 0;
  ParseNodeNameAsStringPiece(name, &parsed_position);
  return parsed_position;
}
// A utility class to look up a node and its outputs (fanouts) by node name.
//
// Lookups that are only used to find an already-registered node go through
// FindNodeOrNull(), which never inserts into `nodes_`. Using `nodes_[...]`
// would default-insert a null entry for an unknown name, after which
// NodeExists() reports a node that GetNode() returns as nullptr, and the null
// pointer could leak into the fanout sets.
class NodeMap {
 public:
  // Note: The NodeMap will store pointers to nodes in graph, which may become
  // invalid if graph is changed.
  explicit NodeMap(GraphDef* graph);

  // Returns the recorded fanouts of `node_name`, or an empty set if there are
  // none. The iteration order is non-deterministic; use
  // GetOutputsOrderedByNodeName() when a deterministic order is required.
  const absl::flat_hash_set<NodeDef*>& GetOutputs(
      const string& node_name) const {
    auto it = outputs_.find(node_name);
    if (it == outputs_.end()) {
      return empty_set_;
    }
    return it->second;
  }

  // Returns the recorded fanouts of `node_name` sorted by node name, or an
  // empty vector if there are none.
  std::vector<NodeDef*> GetOutputsOrderedByNodeName(
      const string& node_name) const {
    std::vector<NodeDef*> result;
    auto it = outputs_.find(node_name);
    if (it != outputs_.end()) {
      const absl::flat_hash_set<NodeDef*>& outputs = it->second;
      result.assign(outputs.begin(), outputs.end());
      std::sort(result.begin(), result.end(),
                [](const NodeDef* n1, const NodeDef* n2) {
                  return n1->name() < n2->name();
                });
    }
    return result;
  }

  // Registers `node` under `node_name`. This method doesn't record the
  // outputs of the added node; they must be added explicitly with AddOutput.
  // DCHECK-fails if `node` is null or `node_name` is already registered.
  void AddNode(const string& node_name, NodeDef* node) {
    DCHECK(node != nullptr);
    auto ret = nodes_.emplace(node_name, node);
    DCHECK(ret.second)
        << "Pair (" << node_name << "," << node
        << ") is not inserted because the same key already exists.";
  }

  // Removes the node keyed by NodeName(name) and its own fanout set. Pointers
  // to the removed node that are stored in other nodes' fanout sets are not
  // touched.
  void RemoveNode(const string& name) {
    const string node_name = NodeName(name);
    nodes_.erase(node_name);
    outputs_.erase(node_name);
  }

  // Returns the node keyed by NodeName(name), or nullptr (logged at VLOG(1))
  // if no such node is registered.
  NodeDef* GetNode(const string& name) const {
    const string node_name = NodeName(name);
    auto it = nodes_.find(node_name);
    if (it == nodes_.end()) {
      VLOG(1) << "Node could not be found: " << name;
      return nullptr;
    }
    return it->second;
  }

  // Returns true if a node is registered under NodeName(name).
  bool NodeExists(const string& name) const {
    const string node_name = NodeName(name);
    return nodes_.find(node_name) != nodes_.end();
  }

  // Records the node referenced by `output_name` as a fanout of `node_name`.
  // The output node must already be registered: a missing one DCHECK-fails,
  // and in non-debug builds the call records nothing.
  void AddOutput(const string& node_name, const string& output_name) {
    NodeDef* output_node = FindNodeOrNull(output_name);
    DCHECK(output_node) << "Output node " << output_name
                        << " is missing in NodeMap.";
    if (output_node == nullptr) return;
    outputs_[node_name].insert(output_node);
  }

  // Removes the node referenced by `output_name` from the fanouts of
  // `node_name`. Does nothing if either is unknown.
  void RemoveOutput(const string& node_name, const string& output_name) {
    auto it = outputs_.find(node_name);
    if (it == outputs_.end()) return;
    NodeDef* output_node = FindNodeOrNull(output_name);
    if (output_node != nullptr) {
      it->second.erase(output_node);
    }
  }

  // Moves `node_name` from the fanouts of `old_input_name` to the fanouts of
  // `new_input_name`.
  void UpdateInput(const string& node_name, const string& old_input_name,
                   const string& new_input_name) {
    RemoveOutput(NodeName(old_input_name), node_name);
    AddOutput(NodeName(new_input_name), node_name);
  }

  // Removes the node registered under `node_name` from the fanouts of each of
  // its inputs. A missing node DCHECK-fails and is otherwise ignored, instead
  // of dereferencing a null pointer.
  void RemoveInputs(const string& node_name) {
    auto it = nodes_.find(node_name);
    DCHECK(it != nodes_.end())
        << "Node " << node_name << " is missing in NodeMap.";
    if (it == nodes_.end()) return;
    const NodeDef* node = it->second;
    for (const auto& input : node->input()) {
      RemoveOutput(NodeName(input), node->name());
    }
  }

  // Drops the whole fanout set recorded for `node_name`.
  void RemoveOutputs(const string& node_name) { outputs_.erase(node_name); }

  // In the fanouts of `node_name`, replaces the node referenced by
  // `old_output_name` with the node referenced by `new_output_name`.
  // Unregistered output names are skipped rather than recorded as nullptr.
  void UpdateOutput(const string& node_name, const string& old_output_name,
                    const string& new_output_name) {
    absl::flat_hash_set<NodeDef*>& outputs = outputs_[node_name];
    if (NodeDef* old_output = FindNodeOrNull(old_output_name)) {
      outputs.erase(old_output);
    }
    if (NodeDef* new_output = FindNodeOrNull(new_output_name)) {
      outputs.insert(new_output);
    }
  }

 private:
  // Returns the node keyed by NodeName(name), or nullptr. Never inserts into
  // `nodes_` and, unlike GetNode(), does not log a missing node.
  NodeDef* FindNodeOrNull(const string& name) const {
    auto it = nodes_.find(NodeName(name));
    return it == nodes_.end() ? nullptr : it->second;
  }

  const absl::flat_hash_set<NodeDef*> empty_set_;
  absl::node_hash_map<string, NodeDef*> nodes_;
  absl::node_hash_map<string, absl::flat_hash_set<NodeDef*>> outputs_;
};
// An insertion-ordered collection of unique values: a std::vector that keeps
// the order in which values were added, paired with a hash set that answers
// membership queries quickly. Both always hold the same elements; PushBack
// rejects values that are already present.
template <class T, class Hash = std::hash<T>>
class SetVector {
 public:
  // Appends `value` unless it is already present. Returns true if it was
  // appended and false if it was already in the collection.
  bool PushBack(const T& value) {
    const bool newly_inserted = set_.insert(value).second;
    if (newly_inserted) {
      vector_.push_back(value);
    }
    return newly_inserted;
  }

  // Removes and returns the most recently appended value. Must not be called
  // on an empty collection (vector_.back() would be undefined).
  T PopBack() {
    T last = vector_.back();
    vector_.pop_back();
    set_.erase(last);
    return last;
  }

  // Returns true if `value` is currently in the collection.
  bool Exists(const T& value) const {
    const auto found = set_.find(value);
    return found != set_.end();
  }

  // Returns true if the collection holds no values.
  bool Empty() const { return vector_.empty(); }

  // Pre-allocates room for `size` values in the ordered vector only; the hash
  // set is not resized.
  void Reserve(int64 size) { vector_.reserve(size); }

 private:
  gtl::FlatSet<T, Hash> set_;
  std::vector<T> vector_;
};
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
string TensorIdToString(const TensorId& tensor_id);
// Returns formatted string from SafeTensorId specific to grappler.
// Specifically, for the 0 port (first output), only the node name is returned.
string SafeTensorIdToString(const SafeTensorId& tensor_id);
// True iff 'name' refers to a control input, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
// True iff the tensor index of 'tensor_id' refers to a control input.
bool IsControlInput(const TensorId& tensor_id);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
// Adds 'prefix' to the node name 'name', separated by 'delimiter'.
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter);
// Adds 'prefix' to the node name 'name' (using the default delimiter chosen
// by the implementation).
string AddPrefixToNodeName(const string& name, const string& prefix);
// Executes 'fn' in 'thread_pool' and waits up to 'timeout_in_ms' milliseconds
// for it to complete. Returns false if 'fn' did not complete within the
// timeout.
//
// If returning false, the 'fn' may still continue to execute in the
// thread-pool. It is the responsibility of the caller to reset the thread-pool
// as appropriate.
bool ExecuteWithTimeout(std::function<void()> fn, int64 timeout_in_ms,
thread::ThreadPool* thread_pool);
// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a NodeDef.
string AsControlDependency(const NodeDef& node);
// Returns the node name prefixed with conventional symbol '^'
// for control dependency, given a node name
string AsControlDependency(const string& node);
// Returns true if the node is assigned to run on CPU device.
bool NodeIsOnCpu(const NodeDef* node);
// Returns true if the node is assigned to run on GPU device.
bool NodeIsOnGpu(const NodeDef* node);
// Returns the number of outputs of a node according to its OpDef. Note that
// some of the outputs may be unconnected.
int NumOutputs(const NodeDef& node, GraphDef* graph);
// Returns true iff the node has at least one control input.
bool HasControlInputs(const NodeDef& node);
// Returns true iff the node has at least one regular input.
bool HasRegularInputs(const NodeDef& node);
// Returns true iff the node has at least one regular output.
bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
// Returns true iff the node has at least one control output.
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Number of connected control inputs.
int NumControlInputs(const NodeDef& node);
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);
// Number of connected control outputs.
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Number of connected non-control data outputs (Ops that consume output tensor
// data, not just its shape).
int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);
// Returns an error if an attribute with the given key does not exist in node.
Status CheckAttrExists(const NodeDef& node, const string& key);
// Returns an error if attributes with the given keys do not exist in node.
Status CheckAttrsExist(const NodeDef& node, absl::Span<const string> keys);
// Returns the data type in attribute `type_attr` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& type_attr);
// Returns the last node in the simple chain starting at source and traversing
// through the input(0) edge from each node as long as the next node satisfies
// the predicate given in pred_fn. If no nodes satisfy the predicate, &source
// will be returned. Example: For the chain
// source <- a <- b <- ... <- y <- z
// where
// pred_fn(a) = pred_fn(b) = ... = pred_fn(y) = true,
// pred_fn(z) = false,
// the return value will be a pointer to y.
// NOTE(review): `follow_control_input` presumably controls whether a control
// input in position 0 is also followed; confirm in the implementation.
NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map,
bool follow_control_input,
const std::function<bool(const NodeDef&)>& pred_fn);
// Permute the nodes of graph in place according to the permutation.
// NOTE(review): the direction of the mapping in `permutation` and the effect
// of `invert_permutation` are defined by the implementation; confirm there.
void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
bool invert_permutation);
// Returns Status::OK() if a kernel is registered for node.op() on the device
// type corresponding to node.device().
// This overload takes the relevant NodeDef fields individually.
Status IsKernelRegisteredForNode(
absl::string_view node_name, bool has_experimental_debug_info,
const NodeDef_ExperimentalDebugInfo& experimental_debug_info,
absl::string_view node_op, absl::string_view node_device,
AttrSlice node_attrs);
// Same check as above, taking the fields from `node`.
Status IsKernelRegisteredForNode(const NodeDef& node);
// NOTE(review): presumably fills `tensor` with `value` converted to `dtype`
// and returns an error for unsupported dtypes; confirm in the implementation.
Status SetTensorValue(DataType dtype, int value, Tensor* tensor);
// Removes nodes from `graph`. The int overloads presumably take indices into
// graph->node() and the string overload takes node names; confirm in the
// implementation.
void EraseNodesFromGraph(const std::set<int>& nodes_to_delete, GraphDef* graph);
void EraseNodesFromGraph(std::vector<int>&& nodes_to_delete, GraphDef* graph);
void EraseNodesFromGraph(const std::set<string>& nodes_to_delete,
GraphDef* graph);
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_H_