[Grappler] Micro optimizations:
1. Use flat_hash_{map,set}, keyed on string_view, in shape inference code and a few grappler utility functions.
2. Get rid of an unnecessary graph copy in remapper.
Speeds up constant folding by about 6.6% and the remapper by about 20% on a large inference graph. The total speedup for grappler optimization is about 4.6%.
PiperOrigin-RevId: 304428686
Change-Id: I1b83916f2080fa1f86130b250828518faaa03362
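For context on point 1: absl's flat hash containers store their elements in a single contiguous array (better cache locality than the node-based std::unordered_* containers) and, with the default hash and equality functors, support heterogeneous lookup, so a set of std::string can be probed with an absl::string_view without materializing a temporary string per lookup. A minimal sketch, not code from this change; AllInputsKnown and its arguments are hypothetical names used only for illustration:

// Sketch only: heterogeneous (string_view) lookup in an
// absl::flat_hash_set<std::string>; all names here are hypothetical.
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"

bool AllInputsKnown(const std::vector<std::string>& known_nodes,
                    const std::vector<absl::string_view>& inputs) {
  absl::flat_hash_set<std::string> known(known_nodes.begin(),
                                         known_nodes.end());
  for (absl::string_view input : inputs) {
    // contains() accepts a string_view directly; no temporary std::string
    // is allocated per lookup, unlike std::unordered_set<std::string>.
    if (!known.contains(input)) return false;
  }
  return true;
}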
Parent: 1e9b684f25
Commit: eb0998340c
@@ -40,6 +40,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
@@ -66,6 +66,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":utils",
+        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/types:optional",
         "//tensorflow/core/grappler/utils:functions",
         "//tensorflow/core/grappler/utils:topological_sort",
@@ -161,6 +162,7 @@ tf_cuda_library(
     deps = [
         ":cost_estimator",
         "//third_party/eigen3",
+        "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "//tensorflow/core:framework",
@@ -201,7 +201,7 @@ class DisjointSet {
 
  private:
   Processor<Handle> processor_;
-  std::unordered_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
+  absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
       nodes_;
 };
 
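The DisjointSet hunk above swaps the container but keeps the existing HashHandle/CompareHandle functors. A standalone sketch (the types below are stand-ins, not the TensorFlow ones) of how absl::flat_hash_map accepts the same Hash and Eq template parameters as std::unordered_map, which is what makes this a drop-in change:

// Sketch only: absl::flat_hash_map with custom hash/equality functors,
// mirroring the Handle/HashHandle/CompareHandle pattern. Hypothetical types.
#include <cstddef>

#include "absl/container/flat_hash_map.h"

struct Handle {
  int id;
};

struct HashHandle {
  std::size_t operator()(const Handle& h) const {
    return static_cast<std::size_t>(h.id);
  }
};

struct CompareHandle {
  bool operator()(const Handle& a, const Handle& b) const {
    return a.id == b.id;
  }
};

int* Find(absl::flat_hash_map<Handle, int, HashHandle, CompareHandle>& nodes,
          Handle h) {
  auto it = nodes.find(h);
  return it == nodes.end() ? nullptr : &it->second;
}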
@@ -297,9 +297,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
 // This really should be done in an external debugging tool
 void VerboseLogUnknownDimensionSources(
     const GraphDef& graph,
-    const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
+    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
         input_properties_map,
-    const std::unordered_map<string, std::vector<OpInfo::TensorProperties>>&
+    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
         output_properties_map) {
   if (!VLOG_IS_ON(2)) {
     return;
@@ -497,9 +497,9 @@ class TopoQueue {
     }
   };
 
-  const std::unordered_map<const NodeDef*, int> TopoOrder(
+  const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
       const std::vector<const NodeDef*>& topo_order) const {
-    std::unordered_map<const NodeDef*, int> map;
+    absl::flat_hash_map<const NodeDef*, int> map;
     map.reserve(topo_order.size());
     for (int i = 0; i < topo_order.size(); ++i) {
       map.emplace(topo_order[i], i);
@@ -507,7 +507,7 @@ class TopoQueue {
     return map;
   }
 
-  const std::unordered_map<const NodeDef*, int> topo_order_;
+  const absl::flat_hash_map<const NodeDef*, int> topo_order_;
   std::set<NodeAndId, OrderByIdAscending> queue_;
 };
 
@@ -599,7 +599,7 @@ class SymbolicShapeRefiner {
  public:
   explicit SymbolicShapeRefiner(
       const GraphView& graph,
-      const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+      const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
       const bool aggressive_shape_inference)
       : graph_(graph),
         function_library_(OpRegistry::Global(), graph.graph()->library()),
@@ -1917,20 +1917,20 @@ class SymbolicShapeRefiner {
 
   const GraphView& graph_;
   int graph_def_version_;
-  std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
-  std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
-  std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
+  absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
+  absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
+  absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
   // Store function instantiations only for valid function. If function
   // instantiation failed it will have an `absl::nullopt`.
-  std::unordered_map<string, absl::optional<GrapplerFunctionItem>>
+  absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
       fun_to_grappler_function_item_;
   FunctionLibraryDefinition function_library_;
-  const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
-  // Store TensorProtos for tensor value propagation. Note that we use list, not
-  // vector, as we use pointers to the TensorProtos in this container. Vector
-  // may resize and copy the objects into a new buffer, then the existing
+  const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
+  // Store TensorProtos for tensor value propagation. Note that we use deque,
+  // not vector, as we use pointers to the TensorProtos in this container.
+  // Vector may resize and copy the objects into a new buffer, then the existing
   // pointers become dangling pointers.
-  std::list<TensorProto> const_tensors_to_propagate_;
+  std::deque<TensorProto> const_tensors_to_propagate_;
 
   // For more aggressive shape and value inference.
   bool aggressive_shape_inference_;
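The comment in the hunk above carries the reasoning behind the std::list to std::deque swap: pointers into const_tensors_to_propagate_ must stay valid while later elements are appended, which rules out std::vector. A self-contained illustration of that pointer-stability difference, with TensorProto replaced by a stub struct:

// Sketch of the pointer-stability argument from the comment above.
// `Payload` is a hypothetical stand-in for TensorProto.
#include <deque>
#include <vector>

struct Payload {
  int value = 0;
};

void PointerStability() {
  std::deque<Payload> stable;
  stable.push_back({1});
  Payload* p = &stable.front();
  for (int i = 0; i < 1000; ++i) stable.push_back({i});
  // `p` is still valid: deque::push_back never relocates existing elements.

  std::vector<Payload> unstable;
  unstable.push_back({1});
  Payload* q = &unstable.front();
  for (int i = 0; i < 1000; ++i) unstable.push_back({i});
  // `q` may now dangle: a vector reallocation copies or moves elements into
  // a new buffer, invalidating earlier pointers.
  (void)p;
  (void)q;
}

Compared with the std::list it replaces, std::deque keeps the same pointer-stability guarantee for appends while storing elements in larger blocks, so it allocates far less often and walks memory more predictably.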
@@ -2093,7 +2093,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
 
 Status GraphProperties::UpdateShapes(
     SymbolicShapeRefiner* shape_refiner,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     const NodeDef* n, bool* new_shapes) const {
   if (IsEnter(*n)) {
     // The Enter shape function always forwards an UnknownShape, so do the right
@@ -2122,7 +2122,7 @@ Status GraphProperties::UpdateShapes(
 // Propagates the shapes in the transitive fan-out of <new_shapes>.
 Status GraphProperties::PropagateShapes(
     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     int num_loops) const {
   // Limit the number of iterations to prevent infinite loops in the presence of
   // incorrect shape functions. The algorithm should converge in at most
@@ -2221,7 +2221,7 @@ Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
 
 Status GraphProperties::UpdateEnqueue(
     const NodeDef* enqueue_node,
-    const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
+    const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
   if (!ctx) {
@@ -2272,7 +2272,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
                                         bool include_output_tensor_values) {
   FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                              item_.graph.library());
-  std::unordered_map<string, std::unordered_set<int>> fed_ports;
+  absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
   if (!assume_valid_feeds) {
     for (const auto& feed : item_.feed) {
       SafeTensorId tensor_id = ParseTensorName(feed.first);
@@ -2284,13 +2284,13 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
 
   // List the resources and the nodes using them. Also collect the Merge nodes,
   // fed nodes, and primary inputs.
-  std::unordered_map<const NodeDef*,
-                     std::pair<std::unordered_set<const NodeDef*>,
-                               std::unordered_set<const NodeDef*>>>
+  absl::flat_hash_map<const NodeDef*,
+                      std::pair<absl::flat_hash_set<const NodeDef*>,
+                                absl::flat_hash_set<const NodeDef*>>>
       resources;
-  std::unordered_set<const NodeDef*> merge_nodes;
-  std::unordered_set<const NodeDef*> fed_nodes;
-  std::unordered_set<const NodeDef*> primary_inputs;
+  absl::flat_hash_set<const NodeDef*> merge_nodes;
+  absl::flat_hash_set<const NodeDef*> fed_nodes;
+  absl::flat_hash_set<const NodeDef*> primary_inputs;
   int num_loops = 0;
   for (const NodeDef& node : item_.graph.node()) {
     if (IsQueue(node)) {
@@ -2327,7 +2327,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
     }
   }
 
-  std::unordered_map<const NodeDef*, const NodeDef*> resource_handles;
+  absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
   std::vector<TopologicalDependency> extra_deps;
   for (const auto& resource : resources) {
     for (const NodeDef* src : resource.second.first) {
@@ -20,6 +20,7 @@ limitations under the License.
 #include <unordered_set>
 #include <vector>
 
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -168,7 +169,7 @@ class GraphProperties {
   // queue, and schedule the reprocessing of the queue if needed.
   static Status UpdateEnqueue(
       const NodeDef* enqueue_node,
-      const std::unordered_map<const NodeDef*, const NodeDef*>&
+      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
           resource_handles,
       SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
 
@@ -187,22 +188,22 @@ class GraphProperties {
   // Update the shapes for node 'n'. If output shapes for n have changed,
   // enqueue its fanout in 'new_shapes'.
   Status UpdateShapes(SymbolicShapeRefiner* shape_refiner,
-                      const std::unordered_map<const NodeDef*, const NodeDef*>&
+                      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
                           resource_handles,
                       const NodeDef* n, bool* new_shapes) const;
   // Propagate the shapes for the nodes enqueued in new_shapes and their
   // transitive fanout until a fixed point is reached.
   Status PropagateShapes(
       SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
-      const std::unordered_map<const NodeDef*, const NodeDef*>&
+      const absl::flat_hash_map<const NodeDef*, const NodeDef*>&
          resource_handles,
       int num_loops) const;
 
   // Data members
   const GrapplerItem& item_;
-  std::unordered_map<string, std::vector<OpInfo::TensorProperties>>
+  absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
       input_properties_;
-  std::unordered_map<string, std::vector<OpInfo::TensorProperties>>
+  absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>
       output_properties_;
   const std::vector<OpInfo::TensorProperties> missing_properties_;
 
@@ -1662,7 +1662,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   TF_RETURN_IF_ERROR(mutation->Apply());
 
-  *optimized_graph = mutable_item.graph;
+  *optimized_graph = std::move(mutable_item.graph);
 
   return Status::OK();
 }
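Point 2 of the commit message lands in the hunk above: the optimized graph was previously deep-copied out of mutable_item, and moving it transfers the underlying storage instead. A hedged sketch of the difference; Publish and scratch_graph are hypothetical names, not part of the remapper:

// Sketch only: copying vs. moving a large protobuf such as GraphDef.
// `Publish` is a made-up helper used purely for illustration.
#include <utility>

#include "tensorflow/core/framework/graph.pb.h"

void Publish(tensorflow::GraphDef* optimized_graph,
             tensorflow::GraphDef scratch_graph) {
  // Copy assignment duplicates every NodeDef in the graph:
  //   *optimized_graph = scratch_graph;
  // Move assignment transfers the repeated NodeDef storage instead of
  // duplicating it, which is what removes the extra graph copy:
  *optimized_graph = std::move(scratch_graph);
}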
@@ -20,6 +20,7 @@ limitations under the License.
 #include <queue>
 #include <vector>
 
+#include "absl/container/flat_hash_set.h"
 #include "absl/strings/match.h"
 #include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
@@ -436,7 +437,7 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation,
 }
 
 void DedupControlInputs(NodeDef* node) {
-  std::unordered_set<string> inputs;
+  absl::flat_hash_set<string> inputs;
   int pos = 0;
   while (pos < node->input_size()) {
     const string& input = node->input(pos);