[Grappler] Micro optimizations:
1. Use flat_hash_{map,set}, keyed on string_view, in the shape inference code and a few Grappler utility functions.
2. Get rid of an unnecessary graph copy in the remapper.

Speeds up constant folding by about 6.6% and the remapper by roughly 20% on a large inference graph. The total speedup for Grappler optimization is about 4.6%. (Hedged illustrative sketches of both techniques are interleaved with the hunks below.)

PiperOrigin-RevId: 304428686
Change-Id: I1b83916f2080fa1f86130b250828518faaa03362
parent 1e9b684f25
commit eb0998340c
tensorflow/core/grappler
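Point 1 of the commit message leans on a property of Abseil's hash containers: with std::string keys, their default hash and equality functors are transparent, so lookups can be performed directly with an absl::string_view and no temporary std::string is constructed (std::unordered_map's defaults cannot do this). A minimal, self-contained sketch of that lookup pattern follows; the struct and names are illustrative and do not come from the TensorFlow sources.

    #include <string>
    #include <vector>

    #include "absl/container/flat_hash_map.h"
    #include "absl/strings/string_view.h"

    // Hypothetical stand-in for OpInfo::TensorProperties.
    struct TensorProps {
      std::vector<int> dims;
    };

    int main() {
      // Keys are std::string, but absl::flat_hash_map's default hash/eq are
      // transparent, so find() accepts an absl::string_view directly.
      absl::flat_hash_map<std::string, TensorProps> props;
      props["conv1/weights"] = TensorProps{{3, 3, 64, 128}};

      absl::string_view node_name = "conv1/weights";  // e.g. parsed from a NodeDef
      auto it = props.find(node_name);                // no temporary std::string
      return it != props.end() ? 0 : 1;
    }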
@@ -40,6 +40,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -66,6 +66,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":utils",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:optional",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler/utils:topological_sort",
@@ -161,6 +162,7 @@ tf_cuda_library(
     deps = [
         ":cost_estimator",
         "//third_party/eigen3",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "//tensorflow/core:framework",
@@ -201,7 +201,7 @@ class DisjointSet {

  private:
   Processor<Handle> processor_;
-  std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
+  absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
       nodes_;
 };

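As the DisjointSet hunk above shows, absl::flat_hash_map takes the same custom hash and equality functor template parameters as std::unordered_map, so types like HashHandle and CompareHandle carry over unchanged. A hedged sketch with a made-up Handle type (the real Handle/HashHandle/CompareHandle live elsewhere in the Grappler cost code and are not reproduced here):

    #include <cstddef>
    #include <cstdint>

    #include "absl/container/flat_hash_map.h"

    // Illustrative stand-ins only.
    struct Handle {
      int64_t id;
    };

    struct HashHandle {
      size_t operator()(const Handle& h) const { return static_cast<size_t>(h.id); }
    };

    struct CompareHandle {
      bool operator()(const Handle& a, const Handle& b) const { return a.id == b.id; }
    };

    int main() {
      // Drop-in replacement: the hash and equality template parameters are the
      // same ones that were passed to std::unordered_map.
      absl::flat_hash_map<Handle, int, HashHandle, CompareHandle> nodes;
      nodes[Handle{42}] = 7;
      return nodes.count(Handle{42}) == 1 ? 0 : 1;
    }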
@@ -297,9 +297,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
 // This really should be done in an external debugging tool
 void VerboseLogUnknownDimensionSources(
     const GraphDef& graph,
-    const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
+    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
         input_properties_map,
-    const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
+    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
         output_properties_map) {
   if (!VLOG_IS_ON(2)) {
     return;
@@ -497,9 +497,9 @@ class TopoQueue {
     }
   };

-  const std::unordered_map<const NodeDef*, int> TopoOrder(
+  const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
       const std::vector<const NodeDef*>& topo_order) const {
-    std::unordered_map<const NodeDef*, int> map;
+    absl::flat_hash_map<const NodeDef*, int> map;
     map.reserve(topo_order.size());
     for (int i = 0; i < topo_order.size(); ++i) {
       map.emplace(topo_order[i], i);
@@ -507,7 +507,7 @@ class TopoQueue {
     return map;
   }

-  const std::unordered_map<const NodeDef*, int> topo_order_;
+  const absl::flat_hash_map<const NodeDef*, int> topo_order_;
   std::set<NodeAndId, OrderByIdAscending> queue_;
 };

@@ -599,7 +599,7 @@ class SymbolicShapeRefiner {
  public:
   explicit SymbolicShapeRefiner(
       const GraphView& graph,
-      const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+      const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
       const bool aggressive_shape_inference)
       : graph_(graph),
         function_library_(OpRegistry::Global(), graph.graph()->library()),
@@ -1917,20 +1917,20 @@ class SymbolicShapeRefiner {

   const GraphView& graph_;
   int graph_def_version_;
-  std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
-  std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
-  std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
+  absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
+  absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
+  absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
   // Store function instantiations only for valid function. If function
   // instantiation failed it will have an `absl::nullopt`.
-  std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
+  absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
       fun_to_grappler_function_item_;
   FunctionLibraryDefinition function_library_;
-  const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
-  // Store TensorProtos for tensor value propagation. Note that we use list, not
-  // vector, as we use pointers to the TensorProtos in this container. Vector
-  // may resize and copy the objects into a new buffer, then the existing
+  const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
+  // Store TensorProtos for tensor value propagation. Note that we use deque,
+  // not vector, as we use pointers to the TensorProtos in this container.
+  // Vector may resize and copy the objects into a new buffer, then the existing
   // pointers become dangling pointers.
-  std::list<TensorProto> const_tensors_to_propagate_;
+  std::deque<TensorProto> const_tensors_to_propagate_;

   // For more aggressive shape and value inference.
   bool aggressive_shape_inference_;
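The rewritten comment in the hunk above is the reason std::list becomes std::deque: both containers keep pointers to existing elements valid when new elements are appended at the ends, whereas std::vector may reallocate and leave the stored TensorProto pointers dangling; deque additionally avoids list's per-element allocations. A small standalone sketch of the invariant being relied on (no TensorFlow types involved):

    #include <cassert>
    #include <deque>
    #include <string>

    int main() {
      std::deque<std::string> storage;
      storage.push_back("const_tensor_0");

      // Keep a raw pointer into the container, as the shape refiner does with
      // pointers to TensorProtos stored in const_tensors_to_propagate_.
      const std::string* first = &storage.front();

      // Appending to either end of a deque never invalidates pointers or
      // references to existing elements (iterators are invalidated, pointers
      // are not). A vector could reallocate here and dangle `first`.
      for (int i = 1; i < 1000; ++i) {
        storage.push_back("const_tensor_" + std::to_string(i));
      }

      assert(*first == "const_tensor_0");  // still valid
      return 0;
    }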
@@ -2093,7 +2093,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,

 Status GraphProperties::UpdateShapes(
     SymbolicShapeRefiner* shape_refiner,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     const NodeDef* n, bool* new_shapes) const {
   if (IsEnter(*n)) {
     // The Enter shape function always forwards an UnknownShape, so do the right
@@ -2122,7 +2122,7 @@ Status GraphProperties::UpdateShapes(
 // Propagates the shapes in the transitive fan-out of <new_shapes>.
 Status GraphProperties::PropagateShapes(
     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     int num_loops) const {
   // Limit the number of iterations to prevent infinite loops in the presence of
   // incorrect shape functions. The algorithm should converge in at most
@@ -2221,7 +2221,7 @@ Status GraphProperties::UpdateQueue(const NodeDef* queue_node,

 Status GraphProperties::UpdateEnqueue(
     const NodeDef* enqueue_node,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
   if (!ctx) {
@@ -2272,7 +2272,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
                                         bool include_output_tensor_values) {
   FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                              item_.graph.library());
-  std::unordered_map<string, std::unordered_set<int>> fed_ports;
+  absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
   if (!assume_valid_feeds) {
     for (const auto& feed : item_.feed) {
       SafeTensorId tensor_id = ParseTensorName(feed.first);
@@ -2284,13 +2284,13 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,

   // List the resources and the nodes using them. Also collect the Merge nodes,
   // fed nodes, and primary inputs.
-  std::unordered_map<const NodeDef*,
-                     std::pair<std::unordered_set<const NodeDef*>,
-                               std::unordered_set<const NodeDef*>>>
+  absl::flat_hash_map<const NodeDef*,
+                      std::pair<absl::flat_hash_set<const NodeDef*>,
+                                absl::flat_hash_set<const NodeDef*>>>
       resources;
-  std::unordered_set<const NodeDef*> merge_nodes;
-  std::unordered_set<const NodeDef*> fed_nodes;
-  std::unordered_set<const NodeDef*> primary_inputs;
+  absl::flat_hash_set<const NodeDef*> merge_nodes;
+  absl::flat_hash_set<const NodeDef*> fed_nodes;
+  absl::flat_hash_set<const NodeDef*> primary_inputs;
   int num_loops = 0;
   for (const NodeDef& node : item_.graph.node()) {
     if (IsQueue(node)) {
@@ -2327,7 +2327,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
     }
   }

-  std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
+  absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
   std::vector<TopologicalDependency> extra_deps;
   for (const auto& resource : resources) {
     for (const NodeDef* src : resource.second.first) {
@@ -20,6 +20,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>

+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -168,7 +169,7 @@ class GraphProperties {
   // queue, and schedule the reprocessing of the queue if needed.
   static Status UpdateEnqueue(
       const NodeDef* enqueue_node,
-      const std::unordered_map<const NodeDef*, const NodeDef*>&
+      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
           resource_handles,
       SymbolicShapeRefiner* shape_refiner, bool* new_shapes);

@@ -187,22 +188,22 @@ class GraphProperties {
   // Update the shapes for node 'n'. If output shapes for n have changed,
   // enqueue its fanout in 'new_shapes'.
   Status UpdateShapes(SymbolicShapeRefiner* shape_refiner,
-                      const std::unordered_map<const NodeDef*, const NodeDef*>&
+                      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
                           resource_handles,
                       const NodeDef* n, bool* new_shapes) const;
   // Propagate the shapes for the nodes enqueued in new_shapes and their
   // transitive fanout until a fixed point is reached.
   Status PropagateShapes(
       SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
-      const std::unordered_map<const NodeDef*, const NodeDef*>&
+      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
           resource_handles,
       int num_loops) const;

   // Data members
   const GrapplerItem& item_;
-  std::unordered_map<string, std::vector<OpInfo::TensorProperties>>
+  absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
       input_properties_;
-  std::unordered_map<string, std::vector<OpInfo::TensorProperties>>
+  absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
       output_properties_;
   const std::vector<OpInfo::TensorProperties> missing_properties_;

@@ -1662,7 +1662,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   TF_RETURN_IF_ERROR(mutation->Apply());

-  *optimized_graph = mutable_item.graph;
+  *optimized_graph = std::move(mutable_item.graph);

   return Status::OK();
 }
@@ -20,6 +20,7 @@ limitations under the License.
 #include <queue>
 #include <vector>

+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -436,7 +437,7 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
 }

 void DedupControlInputs(NodeDef* node) {
-  std::unordered_set<string> inputs;
+  absl::flat_hash_set<string> inputs;
   int pos = 0;
   while (pos < node->input_size()) {
     const string& input = node->input(pos);
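DedupControlInputs only needs fast membership tests on input names, hence the switch to absl::flat_hash_set<string>. Below is a hedged, standalone sketch of the same first-occurrence dedup pattern; the real function edits a NodeDef in place and concerns itself with control inputs (names prefixed with '^'), which is not reproduced here.

    #include <string>
    #include <vector>

    #include "absl/container/flat_hash_set.h"

    // Remove duplicate entries from a list of input names, keeping the first
    // occurrence of each. Simplified stand-in for the in-place NodeDef logic.
    std::vector<std::string> DedupInputs(const std::vector<std::string>& inputs) {
      absl::flat_hash_set<std::string> seen;
      std::vector<std::string> result;
      for (const std::string& input : inputs) {
        if (seen.insert(input).second) {  // true only for the first occurrence
          result.push_back(input);
        }
      }
      return result;
    }

    int main() {
      std::vector<std::string> deduped =
          DedupInputs({"^ctrl_a", "x", "^ctrl_a", "^ctrl_b"});
      return deduped.size() == 3 ? 0 : 1;
    }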