diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 3a1ab28dae9..512c3b07b40 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -151,6 +151,8 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/utils:functions", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index 73c950b3fce..d074676c3d4 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -15,10 +15,11 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/function_optimizer.h" -#include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_replace.h" #include "absl/strings/substitute.h" @@ -163,10 +164,10 @@ struct FunctionSpecializationSignature { string func_name; bool is_in_fetch_set; - gtl::FlatSet active_outputs; - std::unordered_map type_parameters; - std::unordered_map body_parameters; - std::unordered_map const_inputs; + absl::flat_hash_set active_outputs; + absl::flat_hash_map type_parameters; + absl::flat_hash_map body_parameters; + absl::flat_hash_map const_inputs; bool operator==(const FunctionSpecializationSignature& other) const { bool equals = func_name == other.func_name && @@ -189,48 +190,45 @@ struct FunctionSpecializationSignature { return true; } - // TODO(ezhulenev): Migrate to AbslHashValue. - // TODO(ezhulenev): Optimize performance by computing hashes of unordered - // values first, and then compute a hash of sorted hashes. - struct Hash { - uint64 operator()(FunctionSpecializationSignature const& s) const { - uint64 h = Hash64(s.func_name); - h = Hash64Combine(std::hash()(s.is_in_fetch_set), h); + template + friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) { + H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set); - // Use std::set/std::map for deterministic iteration order. + // First pre-compute hashes for all values in collections with + // non-deterministic iteration order. + std::vector hashes; + hashes.reserve(s.active_outputs.size() // + + s.type_parameters.size() * 2 // + + s.body_parameters.size() * 2 // + + s.const_inputs.size() * 2); - std::set active_outputs(s.active_outputs.begin(), - s.active_outputs.end()); - for (const auto& active_output : active_outputs) { - h = Hash64Combine(std::hash()(active_output), h); - } + absl::c_transform(s.active_outputs, std::back_inserter(hashes), + hash()); - std::map types(s.type_parameters.begin(), - s.type_parameters.end()); - for (const auto& pair : types) { - AttrValue attr_value; - attr_value.set_type(pair.second); - h = Hash64Combine(Hash64(pair.first), h); - h = Hash64Combine(AttrValueHash(attr_value), h); - } + using TypeParam = std::pair; + absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) { + AttrValue attr_value; + attr_value.set_type(type_param.second); + hashes.push_back(Hash64(type_param.first)); + hashes.push_back(AttrValueHash(attr_value)); + }); - std::map body(s.body_parameters.begin(), - s.body_parameters.end()); - for (const auto& pair : body) { - h = Hash64Combine(Hash64(pair.first), h); - h = Hash64Combine(FastAttrValueHash(pair.second), h); - } + using BodyParam = std::pair; + absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) { + hashes.push_back(Hash64(body_param.first)); + hashes.push_back(FastAttrValueHash(body_param.second)); + }); - std::map inputs(s.const_inputs.begin(), - s.const_inputs.end()); - for (const auto& pair : inputs) { - h = Hash64Combine(std::hash()(pair.first), h); - h = Hash64Combine(Hash64(pair.second), h); - } + using ConstInput = std::pair; + absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) { + hashes.push_back(hash()(const_input.first)); + hashes.push_back(Hash64(const_input.second)); + }); - return h; - } - }; + // Combine all pre-computed hashes in a deterministic order. + absl::c_sort(hashes); + return H::combine_contiguous(std::move(base), hashes.data(), hashes.size()); + } }; struct FunctionSpecialization { @@ -238,13 +236,13 @@ struct FunctionSpecialization { // True if the function caller node is in GrapplerItem fetch set. bool is_in_fetch_set; // Names of the tensors that were pushed down into the function body. - gtl::FlatSet const_inputs; + absl::flat_hash_set const_inputs; // Control dependencies of pushed down const inputs have to be attached to // function caller node. - gtl::FlatSet control_deps; + absl::flat_hash_set control_deps; // Output tensors (ports) that consumed by other nodes in the graph or in a // GrapplerItem fetch set. - gtl::FlatSet active_outputs; + absl::flat_hash_set active_outputs; // Mapping from original function output port to the output port of // specialized function. If function specialization changes the number of // function outputs it's required to update all node consumers. @@ -285,12 +283,13 @@ class FunctionOptimizerContext { return flr_; } - const gtl::FlatMap& + const absl::flat_hash_map& tensor_mapping() const { return tensor_mapping_; } - const gtl::FlatMap>& control_overrides() const { + const absl::flat_hash_map>& control_overrides() + const { return control_overrides_; } @@ -298,7 +297,9 @@ class FunctionOptimizerContext { const string& grappler_item_id() const { return grappler_item_id_; } - const gtl::FlatSet& fetch_tensors() const { return fetch_tensors_; } + const absl::flat_hash_set& fetch_tensors() const { + return fetch_tensors_; + } const DeviceSet* devices() const { // Create fake devices lazily only if we need a DeviceSet. @@ -365,7 +366,7 @@ class FunctionOptimizerContext { private: void InitializeTrulyConstNodes(const GrapplerItem& item) { - gtl::FlatSet feed_nodes; + absl::flat_hash_set feed_nodes; for (const auto& feed : item.feed) { feed_nodes.insert(NodeName(feed.first)); } @@ -411,7 +412,7 @@ class FunctionOptimizerContext { FunctionLibraryRuntime* flr_ = nullptr; // Fully defined names of the devices available to the GrapplerItem. - const gtl::FlatSet available_device_names_; + const absl::flat_hash_set available_device_names_; // List of available `FakedDevices` (lazily initialized, see devices()). mutable std::vector> available_devices_; @@ -421,16 +422,15 @@ class FunctionOptimizerContext { mutable DeviceSet available_device_set_; // Nodes that are Const and not in feed. - std::unordered_map truly_const_nodes_; + absl::flat_hash_map truly_const_nodes_; // Specialized functions. - std::unordered_map + absl::flat_hash_map specialized_functions_; // GrapplerItem.fetch is a vector of tensors. - gtl::FlatSet fetch_tensors_; // format: node_name:port - gtl::FlatSet fetch_nodes_; // format: node_name + absl::flat_hash_set fetch_tensors_; // format: node_name:port + absl::flat_hash_set fetch_nodes_; // format: node_name // After function inlining and specialization, the optimized graph might be in // invalid state, nodes can read from non-existing function call nodes that @@ -439,7 +439,7 @@ class FunctionOptimizerContext { // // Tensor mapping that has to be applied to the graph after all functions // optimizations (invalidated tensor id -> optimized graph tensor id). - gtl::FlatMap + absl::flat_hash_map tensor_mapping_; // When we inline a function into the optimized graph, we no longer have the @@ -448,7 +448,7 @@ class FunctionOptimizerContext { // to all side-effectful ops inside the function body. // // Invalidated function call node name -> Inlined side-effectful nodes - gtl::FlatMap> control_overrides_; + absl::flat_hash_map> control_overrides_; // Use graph view to find active outputs of the function caller nodes. GraphView graph_view_; @@ -472,10 +472,10 @@ const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx, return ctx.function_library().Find(node.op()); } -gtl::FlatSet GetActiveOutputs(const NodeDef& node, - const FunctionOptimizerContext& ctx, - int size_hint = 0) { - gtl::FlatSet active_outputs; +absl::flat_hash_set GetActiveOutputs(const NodeDef& node, + const FunctionOptimizerContext& ctx, + int size_hint = 0) { + absl::flat_hash_set active_outputs; active_outputs.reserve(static_cast(size_hint)); // 1. Output can be consumed by the other graph node. @@ -508,7 +508,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func, // number of output args is the same as number of possible function caller // node outputs. int num_outputs = func.signature().output_arg_size(); - const gtl::FlatSet active_outputs = + const absl::flat_hash_set active_outputs = GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs); return active_outputs.size() != num_outputs; @@ -519,7 +519,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func, FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib, const GraphDef& optimized_graph) { FunctionLibraryDefinition pruned_flib = - ReachableFunctionLibraryDefinition(flib, optimized_graph); + flib.ReachableDefinitions(optimized_graph); int pruned_functions = static_cast(pruned_flib.num_functions()) - static_cast(flib.num_functions()); @@ -534,8 +534,8 @@ FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib, Status PushDownConstInputs(const NodeDef& func_node, const FunctionOptimizerContext& ctx, GrapplerFunctionItem* item, - gtl::FlatSet* const_inputs, - gtl::FlatSet* control_deps) { + absl::flat_hash_set* const_inputs, + absl::flat_hash_set* control_deps) { // Record node control dependencies in the control_deps set. const auto record_control_deps = [&](const NodeDef* const_input) { for (int i = const_input->input_size() - 1; i >= 0; --i) { @@ -585,7 +585,7 @@ void RemovePushedDownConstInputs(const FunctionSpecialization& specialization, // Attach control dependencies of pushed down const input to the caller node. if (!specialization.control_deps.empty()) { - gtl::FlatSet existing_control_deps; + absl::flat_hash_set existing_control_deps; for (const string& input : keep_inputs) { existing_control_deps.insert(AsControlDependency(NodeName(input))); @@ -797,8 +797,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, // Push const inputs into the function body, and keep track of their control // dependencies. - gtl::FlatSet const_inputs; - gtl::FlatSet control_deps; + absl::flat_hash_set const_inputs; + absl::flat_hash_set control_deps; TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs, &control_deps)); @@ -1005,7 +1005,7 @@ Status InlineDirectFunctionCall(const NodeDef& func_node, // Mapping from input placeholder name to function input position. int idx = 0; - std::unordered_map input_placeholders_idx; + absl::flat_hash_map input_placeholders_idx; for (const InputArgExpansion& input_arg : item.inputs()) { for (const string& placeholder : input_arg.placeholders) { input_placeholders_idx[placeholder] = idx++; @@ -1699,7 +1699,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item, if (!ctx.control_overrides().empty()) { for (NodeDef& node : *optimized_graph->mutable_node()) { // Keep track of new control inputs to the node. - gtl::FlatSet add_ctrl_inputs; + absl::flat_hash_set add_ctrl_inputs; // Remove all invalidated control inputs. for (int idx = 0; idx < node.input_size(); /* see below */) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 67699b093d4..a84bb1d62f7 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -427,6 +427,14 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, VLOG(1) << "Starting optimization for grappler item: " << item.id; optimization_results_.clear(); + // Constructs a FunctionLibraryDefinition with functions that are reachable + // from the nodes of the graph. + const auto minimized_flib = + [](const GraphDef& graph) -> FunctionLibraryDefinition { + return FunctionLibraryDefinition(OpRegistry::Global(), graph.library()) + .ReachableDefinitions(graph); + }; + // 0. Original graph might contain a huge function library, that is mostly // unused. This library copied over by each individual Grappler optimizer, // which adds a huge overhead. Before starting optimization passes we just @@ -436,11 +444,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef trimmed_graph; // do not copy graph with a potentially huge library *trimmed_graph.mutable_node() = item.graph.node(); *trimmed_graph.mutable_versions() = item.graph.versions(); - *trimmed_graph.mutable_library() = - grappler::ReachableFunctionLibraryDefinition( - FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()), - item.graph) - .ToProto(); + *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto(); GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph)); @@ -472,10 +476,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } // 2. Optimize functions reachable from the optimized graph. - FunctionLibraryDefinition flib = ReachableFunctionLibraryDefinition( - FunctionLibraryDefinition(OpRegistry::Global(), - optimized_graph->library()), - *optimized_graph); + FunctionLibraryDefinition flib = minimized_flib(*optimized_graph); // Find functions for which we might need to compute a gradient at runtime. absl::flat_hash_set differentiable_functions; diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index cd69cf895c6..89417f85c23 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -178,6 +178,9 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -196,6 +199,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index f2894a942bd..7c2180ae40d 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -14,8 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/functions.h" -#include - +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" -#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { @@ -76,16 +76,6 @@ Status ResolveFunctionBodyNodeAttrPlaceholders( } // namespace -FunctionLibraryDefinition ReachableFunctionLibraryDefinition( - const FunctionLibraryDefinition& flib, const GraphDef& graph) { - return flib.ReachableDefinitions(graph); -} - -FunctionLibraryDefinition ReachableFunctionLibraryDefinition( - const FunctionLibraryDefinition& flib, const FunctionDef& func) { - return flib.ReachableDefinitions(func); -} - void GrapplerFunctionConnectivity::RegisterInputArgExpansion( InputArgExpansion input_arg_expansion) { string input_name = input_arg_expansion.input_name; @@ -94,7 +84,7 @@ void GrapplerFunctionConnectivity::RegisterInputArgExpansion( for (int i = 0; i < placeholders.size(); ++i) { const string& placeholder = input_arg_expansion.placeholders[i]; input_arg_placeholders_.insert( - {placeholder, InputArgPlaceholder{input_name, /*input_position=*/i}}); + {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}}); } input_arg_expansions_.insert( {std::move(input_name), std::move(input_arg_expansion)}); @@ -193,7 +183,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( // If position is not defined expand node output range for (int i = output_range.first; i < output_range.second; ++i) { graph_def_inputs->push_back( - i == 0 ? node_name : strings::StrCat(node_name, ":", i)); + i == 0 ? node_name : absl::StrCat(node_name, ":", i)); } } else { if (position > (output_range.second - output_range.first)) { @@ -203,7 +193,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( } int pos = output_range.first + position; graph_def_inputs->push_back( - pos == 0 ? node_name : strings::StrCat(node_name, ":", pos)); + pos == 0 ? node_name : absl::StrCat(node_name, ":", pos)); } return Status::OK(); @@ -232,39 +222,39 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs( Status GrapplerFunctionConnectivity::AsFunctionDefInput( const string& graph_def_input, string* func_def_input) const { - using gtl::FindOrNull; - if (IsControlInput(graph_def_input)) { *func_def_input = graph_def_input; return Status::OK(); } - int position; - string node_name = ParseNodeName(graph_def_input, &position); - CHECK_GE(position, 0); + const TensorId tensor = ParseTensorName(graph_def_input); + DCHECK_GE(tensor.index(), 0); + + const absl::string_view node_name = tensor.node(); + const int index = tensor.index(); // Check if it's an input arg placeholder - if (position == 0) { - const InputArgPlaceholder* placeholder = - FindOrNull(input_arg_placeholders_, node_name); - if (placeholder != nullptr) { - *func_def_input = strings::StrCat(placeholder->input_name, ":", - placeholder->input_position); + if (tensor.index() == 0) { + const auto is_input_placeholder = input_arg_placeholders_.find(node_name); + if (is_input_placeholder != input_arg_placeholders_.end()) { + const InputArgPlaceholder& placeholder = is_input_placeholder->second; + *func_def_input = + absl::StrCat(placeholder.input_name, ":", placeholder.input_index); return Status::OK(); } } // It must be output from one of the function body nodes - const tensorflow::NameRangeMap* outputs_range_map = - FindOrNull(function_body_outputs_, node_name); - if (outputs_range_map != nullptr) { - for (const auto& el : *outputs_range_map) { + const auto is_body_output = function_body_outputs_.find(tensor.node()); + if (is_body_output != function_body_outputs_.end()) { + const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second; + + for (const auto& el : outputs_range_map) { const auto& output_name = el.first; const auto& output_range = el.second; - if (position >= output_range.first && position < output_range.second) { - int pos = position - output_range.first; - *func_def_input = - strings::StrCat(node_name, ":", output_name, ":", pos); + if (index >= output_range.first && index < output_range.second) { + int pos = index - output_range.first; + *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos); return Status::OK(); } } @@ -426,7 +416,7 @@ bool IsParametrized(const FunctionDef& func) { Status InstantiationTypeParameters( const FunctionDef& func, const AttrSlice& func_instantiation_attr, - std::unordered_map* type_parameters) { + absl::flat_hash_map* type_parameters) { if (!type_parameters->empty()) { return errors::InvalidArgument("Type parameters output map must be empty"); } @@ -454,7 +444,7 @@ Status InstantiationTypeParameters( Status InstantiationBodyParameters( const FunctionDef& func, const AttrSlice& func_instantiation_attr, - std::unordered_map* body_parameters) { + absl::flat_hash_map* body_parameters) { if (!body_parameters->empty()) { return errors::InvalidArgument("Body parameters output map must be empty"); } @@ -514,8 +504,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, // Function body shares the library with the graph that instantiated it. We do // not need a full copy of the function library, just the reachable subset. - *function_body.mutable_library() = - ReachableFunctionLibraryDefinition(flib, func).ToProto(); + *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto(); VLOG(3) << absl::Substitute( "Deleted $0 unreachable functions from the Grappler function item " @@ -645,7 +634,7 @@ Status RegisterGrapplerFunctionConnectivity( return Status::OK(); } -Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, +Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, GrapplerFunctionItem* item) { if (!IsConstant(input_const)) { return errors::InvalidArgument("Input node ", input_const.name(), @@ -657,7 +646,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, // Find input arg expansion and input placeholder position in it for the // given function input position. InputArgExpansion* input_arg_expansion = nullptr; - int placeholder_idx = input_position; + int placeholder_idx = input_index; for (InputArgExpansion& input : inputs) { if (placeholder_idx < input.placeholders.size()) { @@ -668,9 +657,8 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, } if (input_arg_expansion == nullptr) { - return errors::InvalidArgument( - "Input placeholder not found: input_position=", input_position, - " function=", item->id); + return errors::InvalidArgument("Input placeholder not found: input_index=", + input_index, " function=", item->id); } // Delete placeholder from input expansion. @@ -699,7 +687,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, return Status::OK(); } -Status RemoveUnusedOutputs(const gtl::FlatSet& active_outputs, +Status RemoveUnusedOutputs(const absl::flat_hash_set& active_outputs, GrapplerFunctionItem* item, std::vector>* output_mapping) { DCHECK(output_mapping->empty()); @@ -713,7 +701,7 @@ Status RemoveUnusedOutputs(const gtl::FlatSet& active_outputs, } } - gtl::FlatSet unused_output_args; + absl::flat_hash_set unused_output_args; const auto is_unused_output_arg = [&](const OutputArgExpansion& output) { return unused_output_args.find(&output) != unused_output_args.end(); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 038cf5f527e..ce8a3e5ac78 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -18,7 +18,9 @@ limitations under the License. #include #include -#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -30,13 +32,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Returns a copy of FunctionLibraryDefinition with subset of functions that are -// reachable from the nodes of the graph. -FunctionLibraryDefinition ReachableFunctionLibraryDefinition( - const FunctionLibraryDefinition& flib, const GraphDef& graph); -FunctionLibraryDefinition ReachableFunctionLibraryDefinition( - const FunctionLibraryDefinition& flib, const FunctionDef& func); - // Depending on the function instantiation attributes, input argument to the // function might be a single tensor, list of tensors of the same type, or a // list of tensors of different types. @@ -81,12 +76,12 @@ class GrapplerFunctionConnectivity { void RegisterFunctionBodyOutputs(const string& node_name, tensorflow::NameRangeMap&& outputs); - // Expand input encoded in FunctionDef format (name[:output][:position]) into + // Expands input encoded in FunctionDef format (name[:output][:position]) into // multiple inputs in GraphDef format (name[:position]). Status ExpandFunctionDefInput(const string& func_def_input, std::vector* graph_def_inputs) const; - // Update Node inputs from FunctionDef to GraphDef format. + // Updates Node inputs from FunctionDef to GraphDef format. Status ExpandNodeInputs(NodeDef* function_body_node) const; // When expanding inputs in function def format, single input might be @@ -96,29 +91,31 @@ class GrapplerFunctionConnectivity { // instantiation attributes and length of input args (and node def outputs) is // known. - // Map from GraphDef input format to FunctionDef input format using registered - // input arg expansion and function body outputs. + // Converts input name from GraphDef format (name[:position]) to the + // FunctionDef input format (name[:output][:position]) using registered input + // arg expansion and function body outputs. Status AsFunctionDefInput(const string& graph_def_input, string* func_def_input) const; - // Update Node inputs from GraphDef to FunctionDef format. + // Updates Node inputs from GraphDef to FunctionDef format. Status AsFunctionDefNode(NodeDef* function_body_node) const; private: // Mapping from input name to input arg expansion. - std::unordered_map input_arg_expansions_; + absl::flat_hash_map input_arg_expansions_; // Mapping from function body node name to output names range map. - std::unordered_map function_body_outputs_; + absl::flat_hash_map function_body_outputs_; + // For each placeholder added to the function instantiation graph, we keep a + // mapping back to the function input argument name and index. struct InputArgPlaceholder { - string input_name; // Name of the function input argument. - int input_position; // Index of a tensor in the function input argument - // expansion, it can be greater than `0` if input - // argument is a list of tensors (aka list(type)). + string input_name; // Name of the function input argument. + int input_index; // Index of a tensor in the function input argument + // expansion, it can be greater than `0` if input + // argument is a list of tensors (aka list(type)). }; - // Mapping from input arg placeholder to the function input tensor. - std::unordered_map input_arg_placeholders_; + absl::flat_hash_map input_arg_placeholders_; }; // Get Function type attributes using attributes of a node that instantiated @@ -172,7 +169,8 @@ class GrapplerFunctionItem : public GrapplerItem { friend Status ReplaceInputWithConst(const NodeDef&, int, GrapplerFunctionItem*); friend Status RemoveUnusedOutputs( - const gtl::FlatSet& active_outputs, GrapplerFunctionItem* item, + const absl::flat_hash_set& active_outputs, + GrapplerFunctionItem* item, std::vector>* output_mapping); GrapplerFunctionItem(string func_name, string description, @@ -191,7 +189,7 @@ class GrapplerFunctionItem : public GrapplerItem { std::set input_arg_placeholders_; - bool is_stateful_; + bool is_stateful_ = false; }; // Check if function input/output types are fully defined only at instantiation @@ -210,14 +208,14 @@ bool IsParametrized(const FunctionDef& func); // caller node. Return error if type can't be resolved. Status InstantiationTypeParameters( const FunctionDef& func, const AttrSlice& func_instantiation_attr, - std::unordered_map* type_parameters); + absl::flat_hash_map* type_parameters); // Resolve function instantiation body parameters (values for the function body // attr placeholders) from the attributes of the caller node. Return error if // type can't be resolved. Status InstantiationBodyParameters( const FunctionDef& func, const AttrSlice& func_instantiation_attr, - std::unordered_map* body_parameters); + absl::flat_hash_map* body_parameters); // Register GrapplerFunctionItem input arg expansion and function body outputs // in the GrapplerFunctionConnectivity. Use function library definition to @@ -227,7 +225,7 @@ Status RegisterGrapplerFunctionConnectivity( GrapplerFunctionConnectivity* connectivity); // Replace one of the function inputs with a constant. -Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, +Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, GrapplerFunctionItem* item); // Remove function output arguments that do not have any active outputs (output @@ -236,7 +234,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position, // potentially be connected to the same output argument (in case of tensor list // outputs). Add output mapping for all active outputs that changed it's output // position (std::pair). -Status RemoveUnusedOutputs(const gtl::FlatSet& active_outputs, +Status RemoveUnusedOutputs(const absl::flat_hash_set& active_outputs, GrapplerFunctionItem* item, std::vector>* output_mapping); diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 5923850eca6..29d6100d237 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/utils/functions.h" + +#include "absl/container/flat_hash_map.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -77,7 +79,7 @@ TEST_F(FunctionsTest, InstantiationParameters) { func_instantiation_attr["B"].set_type(DT_INT32); func_instantiation_attr["C"].set_type(DT_DOUBLE); - std::unordered_map type_parameters; + absl::flat_hash_map type_parameters; TF_EXPECT_OK(InstantiationTypeParameters( func, AttrSlice(&func_instantiation_attr), &type_parameters)); @@ -86,7 +88,7 @@ TEST_F(FunctionsTest, InstantiationParameters) { EXPECT_EQ(DT_INT32, type_parameters["B"]); EXPECT_EQ(DT_DOUBLE, type_parameters["C"]); - std::unordered_map body_parameters; + absl::flat_hash_map body_parameters; TF_EXPECT_OK(InstantiationBodyParameters( func, AttrSlice(&func_instantiation_attr), &body_parameters));