[Grappler] Micro optimizations:

1. Use flat_hash_{map,set}, keyed on string_view, in shape inference code and a few grappler utility functions.
2. Get rid of an unnecessary graph copy in remapper.

Speeds up constant folding by about 6.6% and remapper by ~20% on a large inference graph. Total speedup for Grappler optimization is about 4.6%.

PiperOrigin-RevId: 304428686
Change-Id: I1b83916f2080fa1f86130b250828518faaa03362
This commit is contained in:
A. Unique TensorFlower 2020-04-02 10:31:18 -07:00 committed by TensorFlower Gardener
parent 1e9b684f25
commit eb0998340c
6 changed files with 39 additions and 34 deletions

View File

@ -40,6 +40,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",

View File

@ -66,6 +66,7 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":utils", ":utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"//tensorflow/core/grappler/utils:functions", "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/utils:topological_sort",
@ -161,6 +162,7 @@ tf_cuda_library(
deps = [ deps = [
":cost_estimator", ":cost_estimator",
"//third_party/eigen3", "//third_party/eigen3",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -201,7 +201,7 @@ class DisjointSet {
private: private:
Processor<Handle> processor_; Processor<Handle> processor_;
std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>> absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
nodes_; nodes_;
}; };
@ -297,9 +297,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
// This really should be done in an external debugging tool // This really should be done in an external debugging tool
void VerboseLogUnknownDimensionSources( void VerboseLogUnknownDimensionSources(
const GraphDef& graph, const GraphDef& graph,
const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>& const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
input_properties_map, input_properties_map,
const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>& const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
output_properties_map) { output_properties_map) {
if (!VLOG_IS_ON(2)) { if (!VLOG_IS_ON(2)) {
return; return;
@ -497,9 +497,9 @@ class TopoQueue {
} }
}; };
const std::unordered_map<const NodeDef*, int> TopoOrder( const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
const std::vector<const NodeDef*>& topo_order) const { const std::vector<const NodeDef*>& topo_order) const {
std::unordered_map<const NodeDef*, int> map; absl::flat_hash_map<const NodeDef*, int> map;
map.reserve(topo_order.size()); map.reserve(topo_order.size());
for (int i = 0; i < topo_order.size(); ++i) { for (int i = 0; i < topo_order.size(); ++i) {
map.emplace(topo_order[i], i); map.emplace(topo_order[i], i);
@ -507,7 +507,7 @@ class TopoQueue {
return map; return map;
} }
const std::unordered_map<const NodeDef*, int> topo_order_; const absl::flat_hash_map<const NodeDef*, int> topo_order_;
std::set<NodeAndId, OrderByIdAscending> queue_; std::set<NodeAndId, OrderByIdAscending> queue_;
}; };
@ -599,7 +599,7 @@ class SymbolicShapeRefiner {
public: public:
explicit SymbolicShapeRefiner( explicit SymbolicShapeRefiner(
const GraphView& graph, const GraphView& graph,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports, const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
const bool aggressive_shape_inference) const bool aggressive_shape_inference)
: graph_(graph), : graph_(graph),
function_library_(OpRegistry::Global(), graph.graph()->library()), function_library_(OpRegistry::Global(), graph.graph()->library()),
@ -1917,20 +1917,20 @@ class SymbolicShapeRefiner {
const GraphView& graph_; const GraphView& graph_;
int graph_def_version_; int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_; absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_; absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_; absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
// Store function instantiations only for valid function. If function // Store function instantiations only for valid function. If function
// instantiation failed it will have an `absl::nullopt`. // instantiation failed it will have an `absl::nullopt`.
std::unordered_map<string, absl::optional<GrapplerFunctionItem>> absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
fun_to_grappler_function_item_; fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_; FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_; const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
// Store TensorProtos for tensor value propagation. Note that we use list, not // Store TensorProtos for tensor value propagation. Note that we use deque,
// vector, as we use pointers to the TensorProtos in this container. Vector // not vector, as we use pointers to the TensorProtos in this container.
// may resize and copy the objects into a new buffer, then the existing // Vector may resize and copy the objects into a new buffer, then the existing
// pointers become dangling pointers. // pointers become dangling pointers.
std::list<TensorProto> const_tensors_to_propagate_; std::deque<TensorProto> const_tensors_to_propagate_;
// For more aggressive shape and value inference. // For more aggressive shape and value inference.
bool aggressive_shape_inference_; bool aggressive_shape_inference_;
@ -2093,7 +2093,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
Status GraphProperties::UpdateShapes( Status GraphProperties::UpdateShapes(
SymbolicShapeRefiner* shape_refiner, SymbolicShapeRefiner* shape_refiner,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles, const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
const NodeDef* n, bool* new_shapes) const { const NodeDef* n, bool* new_shapes) const {
if (IsEnter(*n)) { if (IsEnter(*n)) {
// The Enter shape function always forwards an UnknownShape, so do the right // The Enter shape function always forwards an UnknownShape, so do the right
@ -2122,7 +2122,7 @@ Status GraphProperties::UpdateShapes(
// Propagates the shapes in the transitive fan-out of <new_shapes>. // Propagates the shapes in the transitive fan-out of <new_shapes>.
Status GraphProperties::PropagateShapes( Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles, const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
int num_loops) const { int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of // Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most // incorrect shape functions. The algorithm should converge in at most
@ -2221,7 +2221,7 @@ Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
Status GraphProperties::UpdateEnqueue( Status GraphProperties::UpdateEnqueue(
const NodeDef* enqueue_node, const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles, const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes) { SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
auto ctx = shape_refiner->GetNodeContext(enqueue_node); auto ctx = shape_refiner->GetNodeContext(enqueue_node);
if (!ctx) { if (!ctx) {
@ -2272,7 +2272,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
bool include_output_tensor_values) { bool include_output_tensor_values) {
FunctionLibraryDefinition function_library(OpRegistry::Global(), FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library()); item_.graph.library());
std::unordered_map<string, std::unordered_set<int>> fed_ports; absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
if (!assume_valid_feeds) { if (!assume_valid_feeds) {
for (const auto& feed : item_.feed) { for (const auto& feed : item_.feed) {
SafeTensorId tensor_id = ParseTensorName(feed.first); SafeTensorId tensor_id = ParseTensorName(feed.first);
@ -2284,13 +2284,13 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
// List the resources and the nodes using them. Also collect the Merge nodes, // List the resources and the nodes using them. Also collect the Merge nodes,
// fed nodes, and primary inputs. // fed nodes, and primary inputs.
std::unordered_map<const NodeDef*, absl::flat_hash_map<const NodeDef*,
std::pair<std::unordered_set<const NodeDef*>, std::pair<absl::flat_hash_set<const NodeDef*>,
std::unordered_set<const NodeDef*>>> absl::flat_hash_set<const NodeDef*>>>
resources; resources;
std::unordered_set<const NodeDef*> merge_nodes; absl::flat_hash_set<const NodeDef*> merge_nodes;
std::unordered_set<const NodeDef*> fed_nodes; absl::flat_hash_set<const NodeDef*> fed_nodes;
std::unordered_set<const NodeDef*> primary_inputs; absl::flat_hash_set<const NodeDef*> primary_inputs;
int num_loops = 0; int num_loops = 0;
for (const NodeDef& node : item_.graph.node()) { for (const NodeDef& node : item_.graph.node()) {
if (IsQueue(node)) { if (IsQueue(node)) {
@ -2327,7 +2327,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
} }
} }
std::unordered_map<const NodeDef*, const NodeDef*> resource_handles; absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
std::vector<TopologicalDependency> extra_deps; std::vector<TopologicalDependency> extra_deps;
for (const auto& resource : resources) { for (const auto& resource : resources) {
for (const NodeDef* src : resource.second.first) { for (const NodeDef* src : resource.second.first) {

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@ -168,7 +169,7 @@ class GraphProperties {
// queue, and schedule the reprocessing of the queue if needed. // queue, and schedule the reprocessing of the queue if needed.
static Status UpdateEnqueue( static Status UpdateEnqueue(
const NodeDef* enqueue_node, const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
resource_handles, resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes); SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
@ -187,22 +188,22 @@ class GraphProperties {
// Update the shapes for node 'n'. If output shapes for n have changed, // Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'. // enqueue its fanout in 'new_shapes'.
Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, Status UpdateShapes(SymbolicShapeRefiner* shape_refiner,
const std::unordered_map<const NodeDef*, const NodeDef*>& const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
resource_handles, resource_handles,
const NodeDef* n, bool* new_shapes) const; const NodeDef* n, bool* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their // Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached. // transitive fanout until a fixed point is reached.
Status PropagateShapes( Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
const std::unordered_map<const NodeDef*, const NodeDef*>& const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
resource_handles, resource_handles,
int num_loops) const; int num_loops) const;
// Data members // Data members
const GrapplerItem& item_; const GrapplerItem& item_;
std::unordered_map<string, std::vector<OpInfo::TensorProperties>> absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
input_properties_; input_properties_;
std::unordered_map<string, std::vector<OpInfo::TensorProperties>> absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
output_properties_; output_properties_;
const std::vector<OpInfo::TensorProperties> missing_properties_; const std::vector<OpInfo::TensorProperties> missing_properties_;

View File

@ -1662,7 +1662,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
} }
TF_RETURN_IF_ERROR(mutation->Apply()); TF_RETURN_IF_ERROR(mutation->Apply());
*optimized_graph = mutable_item.graph; *optimized_graph = std::move(mutable_item.graph);
return Status::OK(); return Status::OK();
} }

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <queue> #include <queue>
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h" #include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value.pb.h"
@ -436,7 +437,7 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
} }
void DedupControlInputs(NodeDef* node) { void DedupControlInputs(NodeDef* node) {
std::unordered_set<string> inputs; absl::flat_hash_set<string> inputs;
int pos = 0; int pos = 0;
while (pos < node->input_size()) { while (pos < node->input_size()) {
const string& input = node->input(pos); const string& input = node->input(pos);