[Grappler] Cleanup function/function_optimizer

1. Migrate to absl containers
2. Compute function signature hash without allocating temporary sorted containers

PiperOrigin-RevId: 227587343
Author: Eugene Zhulenev, 2019-01-02 15:00:14 -08:00
Committed by: TensorFlower Gardener
parent 790a635a04
commit a7a3bbfcbf
7 changed files with 149 additions and 154 deletions
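The substantive change is (2): the old FunctionSpecializationSignature::Hash copied every unordered container into a sorted std::set/std::map just to get a deterministic iteration order, while the new AbslHashValue overload hashes each element once and sorts only the per-element hashes. A minimal self-contained sketch of that pattern (the Signature type and its fields below are illustrative, not code from this commit):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    #include "absl/container/flat_hash_set.h"
    #include "absl/hash/hash.h"

    struct Signature {
      std::string func_name;
      absl::flat_hash_set<int> active_outputs;  // iteration order is unspecified

      template <typename H>
      friend H AbslHashValue(H h, const Signature& s) {
        H base = H::combine(std::move(h), s.func_name);
        // Hash each element independently first...
        std::vector<size_t> hashes;
        hashes.reserve(s.active_outputs.size());
        for (int output : s.active_outputs) {
          hashes.push_back(absl::Hash<int>()(output));
        }
        // ...then sort the element hashes so the combined hash does not depend
        // on the set's iteration order, and fold them in as one contiguous run.
        std::sort(hashes.begin(), hashes.end());
        return H::combine_contiguous(std::move(base), hashes.data(),
                                     hashes.size());
      }
    };

    int main() {
      Signature a{"my_func", {1, 2, 3}};
      Signature b{"my_func", {3, 2, 1}};  // same elements, different order
      std::cout << (absl::Hash<Signature>()(a) == absl::Hash<Signature>()(b))
                << "\n";  // prints 1: the hash is order-independent
    }

Two signatures with equal contents hash equally regardless of insertion order, and the only temporary allocation is one flat vector of integers.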

File: tensorflow/core/grappler/optimizers/BUILD

@@ -151,6 +151,8 @@ cc_library(
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/utils:functions",
         "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],

File: tensorflow/core/grappler/optimizers/function_optimizer.cc

@@ -15,10 +15,11 @@ limitations under the License.
 
 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
 
-#include <unordered_map>
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/memory/memory.h"
 #include "absl/strings/str_replace.h"
 #include "absl/strings/substitute.h"
@@ -163,10 +164,10 @@ struct FunctionSpecializationSignature {
   string func_name;
   bool is_in_fetch_set;
-  gtl::FlatSet<OutputPort> active_outputs;
-  std::unordered_map<string, DataType> type_parameters;
-  std::unordered_map<string, AttrValue> body_parameters;
-  std::unordered_map<InputPort, string> const_inputs;
+  absl::flat_hash_set<OutputPort> active_outputs;
+  absl::flat_hash_map<string, DataType> type_parameters;
+  absl::flat_hash_map<string, AttrValue> body_parameters;
+  absl::flat_hash_map<InputPort, string> const_inputs;
 
   bool operator==(const FunctionSpecializationSignature& other) const {
     bool equals = func_name == other.func_name &&
@@ -189,62 +190,59 @@ struct FunctionSpecializationSignature {
     return true;
   }
 
-  // TODO(ezhulenev): Migrate to AbslHashValue.
-  // TODO(ezhulenev): Optimize performance by computing hashes of unordered
-  // values first, and then compute a hash of sorted hashes.
-  struct Hash {
-    uint64 operator()(FunctionSpecializationSignature const& s) const {
-      uint64 h = Hash64(s.func_name);
-      h = Hash64Combine(std::hash<bool>()(s.is_in_fetch_set), h);
-
-      // Use std::set/std::map for deterministic iteration order.
-
-      std::set<OutputPort> active_outputs(s.active_outputs.begin(),
-                                          s.active_outputs.end());
-      for (const auto& active_output : active_outputs) {
-        h = Hash64Combine(std::hash<int>()(active_output), h);
-      }
-
-      std::map<string, DataType> types(s.type_parameters.begin(),
-                                       s.type_parameters.end());
-      for (const auto& pair : types) {
-        AttrValue attr_value;
-        attr_value.set_type(pair.second);
-        h = Hash64Combine(Hash64(pair.first), h);
-        h = Hash64Combine(AttrValueHash(attr_value), h);
-      }
-
-      std::map<string, AttrValue> body(s.body_parameters.begin(),
-                                       s.body_parameters.end());
-      for (const auto& pair : body) {
-        h = Hash64Combine(Hash64(pair.first), h);
-        h = Hash64Combine(FastAttrValueHash(pair.second), h);
-      }
-
-      std::map<InputPort, string> inputs(s.const_inputs.begin(),
-                                         s.const_inputs.end());
-      for (const auto& pair : inputs) {
-        h = Hash64Combine(std::hash<int>()(pair.first), h);
-        h = Hash64Combine(Hash64(pair.second), h);
-      }
-
-      return h;
-    }
-  };
+  template <typename H>
+  friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
+    H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
+
+    // First pre-compute hashes for all values in collections with
+    // non-deterministic iteration order.
+    std::vector<uint64> hashes;
+    hashes.reserve(s.active_outputs.size()         //
+                   + s.type_parameters.size() * 2  //
+                   + s.body_parameters.size() * 2  //
+                   + s.const_inputs.size() * 2);
+
+    absl::c_transform(s.active_outputs, std::back_inserter(hashes),
+                      hash<OutputPort>());
+
+    using TypeParam = std::pair<const string, DataType>;
+    absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
+      AttrValue attr_value;
+      attr_value.set_type(type_param.second);
+      hashes.push_back(Hash64(type_param.first));
+      hashes.push_back(AttrValueHash(attr_value));
+    });
+
+    using BodyParam = std::pair<const string, AttrValue>;
+    absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
+      hashes.push_back(Hash64(body_param.first));
+      hashes.push_back(FastAttrValueHash(body_param.second));
+    });
+
+    using ConstInput = std::pair<const InputPort, string>;
+    absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
+      hashes.push_back(hash<InputPort>()(const_input.first));
+      hashes.push_back(Hash64(const_input.second));
+    });
+
+    // Combine all pre-computed hashes in a deterministic order.
+    absl::c_sort(hashes);
+    return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
+  }
 };
 
 struct FunctionSpecialization {
   string specialized_func_name;
   // True if the function caller node is in GrapplerItem fetch set.
   bool is_in_fetch_set;
   // Names of the tensors that were pushed down into the function body.
-  gtl::FlatSet<string> const_inputs;
+  absl::flat_hash_set<string> const_inputs;
   // Control dependencies of pushed down const inputs have to be attached to
   // function caller node.
-  gtl::FlatSet<string> control_deps;
+  absl::flat_hash_set<string> control_deps;
   // Output tensors (ports) that consumed by other nodes in the graph or in a
   // GrapplerItem fetch set.
-  gtl::FlatSet<int> active_outputs;
+  absl::flat_hash_set<int> active_outputs;
   // Mapping from original function output port to the output port of
   // specialized function. If function specialization changes the number of
   // function outputs it's required to update all node consumers.
@@ -285,12 +283,13 @@ class FunctionOptimizerContext {
     return flr_;
   }
 
-  const gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
+  const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
   tensor_mapping() const {
     return tensor_mapping_;
   }
 
-  const gtl::FlatMap<string, std::vector<string>>& control_overrides() const {
+  const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
+      const {
     return control_overrides_;
   }
@@ -298,7 +297,9 @@ class FunctionOptimizerContext {
   const string& grappler_item_id() const { return grappler_item_id_; }
 
-  const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
+  const absl::flat_hash_set<string>& fetch_tensors() const {
+    return fetch_tensors_;
+  }
 
   const DeviceSet* devices() const {
     // Create fake devices lazily only if we need a DeviceSet.
@@ -365,7 +366,7 @@ class FunctionOptimizerContext {
  private:
   void InitializeTrulyConstNodes(const GrapplerItem& item) {
-    gtl::FlatSet<string> feed_nodes;
+    absl::flat_hash_set<string> feed_nodes;
     for (const auto& feed : item.feed) {
       feed_nodes.insert(NodeName(feed.first));
     }
@@ -411,7 +412,7 @@ class FunctionOptimizerContext {
   FunctionLibraryRuntime* flr_ = nullptr;
 
   // Fully defined names of the devices available to the GrapplerItem.
-  const gtl::FlatSet<string> available_device_names_;
+  const absl::flat_hash_set<string> available_device_names_;
 
   // List of available `FakedDevices` (lazily initialized, see devices()).
   mutable std::vector<std::unique_ptr<Device>> available_devices_;
@@ -421,16 +422,15 @@ class FunctionOptimizerContext {
   mutable DeviceSet available_device_set_;
 
   // Nodes that are Const and not in feed.
-  std::unordered_map<string, const NodeDef*> truly_const_nodes_;
+  absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
   // Specialized functions.
-  std::unordered_map<FunctionSpecializationSignature,
-                     const FunctionSpecialization,
-                     FunctionSpecializationSignature::Hash>
+  absl::flat_hash_map<FunctionSpecializationSignature,
+                      const FunctionSpecialization>
       specialized_functions_;
 
   // GrapplerItem.fetch is a vector of tensors.
-  gtl::FlatSet<string> fetch_tensors_;  // format: node_name:port
-  gtl::FlatSet<string> fetch_nodes_;    // format: node_name
+  absl::flat_hash_set<string> fetch_tensors_;  // format: node_name:port
+  absl::flat_hash_set<string> fetch_nodes_;    // format: node_name
 
   // After function inlining and specialization, the optimized graph might be in
   // invalid state, nodes can read from non-existing function call nodes that
@@ -439,7 +439,7 @@ class FunctionOptimizerContext {
   //
   // Tensor mapping that has to be applied to the graph after all functions
   // optimizations (invalidated tensor id -> optimized graph tensor id).
-  gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
+  absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
       tensor_mapping_;
 
   // When we inline a function into the optimized graph, we no longer have the
@@ -448,7 +448,7 @@ class FunctionOptimizerContext {
   // to all side-effectful ops inside the function body.
   //
   // Invalidated function call node name -> Inlined side-effectful nodes
-  gtl::FlatMap<string, std::vector<string>> control_overrides_;
+  absl::flat_hash_map<string, std::vector<string>> control_overrides_;
 
   // Use graph view to find active outputs of the function caller nodes.
   GraphView graph_view_;
@@ -472,10 +472,10 @@ const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
   return ctx.function_library().Find(node.op());
 }
 
-gtl::FlatSet<int> GetActiveOutputs(const NodeDef& node,
-                                   const FunctionOptimizerContext& ctx,
-                                   int size_hint = 0) {
-  gtl::FlatSet<int> active_outputs;
+absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
+                                          const FunctionOptimizerContext& ctx,
+                                          int size_hint = 0) {
+  absl::flat_hash_set<int> active_outputs;
   active_outputs.reserve(static_cast<size_t>(size_hint));
 
   // 1. Output can be consumed by the other graph node.
@@ -508,7 +508,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
   // number of output args is the same as number of possible function caller
   // node outputs.
   int num_outputs = func.signature().output_arg_size();
-  const gtl::FlatSet<int> active_outputs =
+  const absl::flat_hash_set<int> active_outputs =
       GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
 
   return active_outputs.size() != num_outputs;
@@ -519,7 +519,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
                                         const GraphDef& optimized_graph) {
   FunctionLibraryDefinition pruned_flib =
-      ReachableFunctionLibraryDefinition(flib, optimized_graph);
+      flib.ReachableDefinitions(optimized_graph);
 
   int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
                          static_cast<int>(flib.num_functions());
@@ -534,8 +534,8 @@ FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
 Status PushDownConstInputs(const NodeDef& func_node,
                            const FunctionOptimizerContext& ctx,
                            GrapplerFunctionItem* item,
-                           gtl::FlatSet<string>* const_inputs,
-                           gtl::FlatSet<string>* control_deps) {
+                           absl::flat_hash_set<string>* const_inputs,
+                           absl::flat_hash_set<string>* control_deps) {
   // Record node control dependencies in the control_deps set.
   const auto record_control_deps = [&](const NodeDef* const_input) {
     for (int i = const_input->input_size() - 1; i >= 0; --i) {
@@ -585,7 +585,7 @@ void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
 
   // Attach control dependencies of pushed down const input to the caller node.
   if (!specialization.control_deps.empty()) {
-    gtl::FlatSet<string> existing_control_deps;
+    absl::flat_hash_set<string> existing_control_deps;
 
     for (const string& input : keep_inputs) {
       existing_control_deps.insert(AsControlDependency(NodeName(input)));
@@ -797,8 +797,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
 
   // Push const inputs into the function body, and keep track of their control
   // dependencies.
-  gtl::FlatSet<string> const_inputs;
-  gtl::FlatSet<string> control_deps;
+  absl::flat_hash_set<string> const_inputs;
+  absl::flat_hash_set<string> control_deps;
   TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
                                          &control_deps));
@@ -1005,7 +1005,7 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
   // Mapping from input placeholder name to function input position.
   int idx = 0;
-  std::unordered_map<string, int> input_placeholders_idx;
+  absl::flat_hash_map<string, int> input_placeholders_idx;
   for (const InputArgExpansion& input_arg : item.inputs()) {
     for (const string& placeholder : input_arg.placeholders) {
       input_placeholders_idx[placeholder] = idx++;
@@ -1699,7 +1699,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
   if (!ctx.control_overrides().empty()) {
     for (NodeDef& node : *optimized_graph->mutable_node()) {
       // Keep track of new control inputs to the node.
-      gtl::FlatSet<string> add_ctrl_inputs;
+      absl::flat_hash_set<string> add_ctrl_inputs;
 
       // Remove all invalidated control inputs.
       for (int idx = 0; idx < node.input_size(); /* see below */) {
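A note on the specialized_functions_ change above: absl::flat_hash_map defaults its hasher to absl::Hash<Key>, and absl::Hash picks up the AbslHashValue friend function automatically, which is why the explicit FunctionSpecializationSignature::Hash template argument disappears. A minimal sketch of that mechanism (the Key type is illustrative, not code from this commit):

    #include <string>
    #include <utility>

    #include "absl/container/flat_hash_map.h"

    struct Key {
      std::string name;
      int port;

      bool operator==(const Key& other) const {
        return name == other.name && port == other.port;
      }

      template <typename H>
      friend H AbslHashValue(H h, const Key& key) {
        return H::combine(std::move(h), key.name, key.port);
      }
    };

    int main() {
      // No explicit hash functor needed: absl::flat_hash_map defaults to
      // absl::Hash<Key>, which finds the AbslHashValue friend above.
      absl::flat_hash_map<Key, std::string> specialized;
      specialized[Key{"my_func", 0}] = "my_func_specialized_for_0";
      return specialized.count(Key{"my_func", 0}) == 1 ? 0 : 1;
    }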

File: tensorflow/core/grappler/optimizers/meta_optimizer.cc

@@ -427,6 +427,14 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   VLOG(1) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
+  // Constructs a FunctionLibraryDefinition with functions that are reachable
+  // from the nodes of the graph.
+  const auto minimized_flib =
+      [](const GraphDef& graph) -> FunctionLibraryDefinition {
+    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
+        .ReachableDefinitions(graph);
+  };
+
   // 0. Original graph might contain a huge function library, that is mostly
   // unused. This library copied over by each individual Grappler optimizer,
   // which adds a huge overhead. Before starting optimization passes we just
@@ -436,11 +444,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   GraphDef trimmed_graph;  // do not copy graph with a potentially huge library
   *trimmed_graph.mutable_node() = item.graph.node();
   *trimmed_graph.mutable_versions() = item.graph.versions();
-  *trimmed_graph.mutable_library() =
-      grappler::ReachableFunctionLibraryDefinition(
-          FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()),
-          item.graph)
-          .ToProto();
+  *trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();
 
   GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
@@ -472,10 +476,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
 
   // 2. Optimize functions reachable from the optimized graph.
-  FunctionLibraryDefinition flib = ReachableFunctionLibraryDefinition(
-      FunctionLibraryDefinition(OpRegistry::Global(),
-                                optimized_graph->library()),
-      *optimized_graph);
+  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
 
   // Find functions for which we might need to compute a gradient at runtime.
   absl::flat_hash_set<string> differentiable_functions;
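Both call sites above now funnel through the minimized_flib lambda. Written as a free-standing helper, the same trimming looks like this (a sketch; the TrimmedLibrary name is illustrative, but it uses only APIs that appear in this diff):

    #include "tensorflow/core/framework/function.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/framework/op.h"

    namespace tensorflow {

    // Returns a copy of the graph's function library that keeps only the
    // functions transitively reachable from the graph's nodes.
    FunctionDefLibrary TrimmedLibrary(const GraphDef& graph) {
      FunctionLibraryDefinition flib(OpRegistry::Global(), graph.library());
      return flib.ReachableDefinitions(graph).ToProto();
    }

    }  // namespace tensorflow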

File: tensorflow/core/grappler/utils/BUILD

@@ -178,6 +178,9 @@ cc_library(
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
     ],
 )
@@ -196,6 +199,7 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "@com_google_absl//absl/container:flat_hash_map",
     ],
 )

File: tensorflow/core/grappler/utils/functions.cc

@@ -14,8 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/utils/functions.h"
-#include <unordered_map>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_cat.h"
 #include "absl/strings/substitute.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.h"
@@ -28,7 +29,6 @@ limitations under the License.
 #include "tensorflow/core/framework/versions.pb.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/scanner.h"
 
 namespace tensorflow {
@@ -76,16 +76,6 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
 }  // namespace
 
-FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
-    const FunctionLibraryDefinition& flib, const GraphDef& graph) {
-  return flib.ReachableDefinitions(graph);
-}
-
-FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
-    const FunctionLibraryDefinition& flib, const FunctionDef& func) {
-  return flib.ReachableDefinitions(func);
-}
-
 void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
     InputArgExpansion input_arg_expansion) {
   string input_name = input_arg_expansion.input_name;
@@ -94,7 +84,7 @@ void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
   for (int i = 0; i < placeholders.size(); ++i) {
     const string& placeholder = input_arg_expansion.placeholders[i];
     input_arg_placeholders_.insert(
-        {placeholder, InputArgPlaceholder{input_name, /*input_position=*/i}});
+        {placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
   }
   input_arg_expansions_.insert(
       {std::move(input_name), std::move(input_arg_expansion)});
@@ -193,7 +183,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
       // If position is not defined expand node output range
       for (int i = output_range.first; i < output_range.second; ++i) {
         graph_def_inputs->push_back(
-            i == 0 ? node_name : strings::StrCat(node_name, ":", i));
+            i == 0 ? node_name : absl::StrCat(node_name, ":", i));
       }
     } else {
       if (position > (output_range.second - output_range.first)) {
@@ -203,7 +193,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
       }
       int pos = output_range.first + position;
       graph_def_inputs->push_back(
-          pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
+          pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
     }
 
   return Status::OK();
@@ -232,39 +222,39 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
 Status GrapplerFunctionConnectivity::AsFunctionDefInput(
     const string& graph_def_input, string* func_def_input) const {
-  using gtl::FindOrNull;
-
   if (IsControlInput(graph_def_input)) {
     *func_def_input = graph_def_input;
     return Status::OK();
   }
 
-  int position;
-  string node_name = ParseNodeName(graph_def_input, &position);
-  CHECK_GE(position, 0);
+  const TensorId tensor = ParseTensorName(graph_def_input);
+  DCHECK_GE(tensor.index(), 0);
+
+  const absl::string_view node_name = tensor.node();
+  const int index = tensor.index();
 
   // Check if it's an input arg placeholder
-  if (position == 0) {
-    const InputArgPlaceholder* placeholder =
-        FindOrNull(input_arg_placeholders_, node_name);
-    if (placeholder != nullptr) {
-      *func_def_input = strings::StrCat(placeholder->input_name, ":",
-                                        placeholder->input_position);
+  if (tensor.index() == 0) {
+    const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
+    if (is_input_placeholder != input_arg_placeholders_.end()) {
+      const InputArgPlaceholder& placeholder = is_input_placeholder->second;
+      *func_def_input =
+          absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
       return Status::OK();
     }
   }
 
   // It must be output from one of the function body nodes
-  const tensorflow::NameRangeMap* outputs_range_map =
-      FindOrNull(function_body_outputs_, node_name);
-  if (outputs_range_map != nullptr) {
-    for (const auto& el : *outputs_range_map) {
+  const auto is_body_output = function_body_outputs_.find(tensor.node());
+  if (is_body_output != function_body_outputs_.end()) {
+    const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
+
+    for (const auto& el : outputs_range_map) {
       const auto& output_name = el.first;
       const auto& output_range = el.second;
-      if (position >= output_range.first && position < output_range.second) {
-        int pos = position - output_range.first;
-        *func_def_input =
-            strings::StrCat(node_name, ":", output_name, ":", pos);
+      if (index >= output_range.first && index < output_range.second) {
+        int pos = index - output_range.first;
+        *func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
         return Status::OK();
       }
     }
@@ -426,7 +416,7 @@ bool IsParametrized(const FunctionDef& func) {
 Status InstantiationTypeParameters(
     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
-    std::unordered_map<string, DataType>* type_parameters) {
+    absl::flat_hash_map<string, DataType>* type_parameters) {
   if (!type_parameters->empty()) {
     return errors::InvalidArgument("Type parameters output map must be empty");
   }
@@ -454,7 +444,7 @@ Status InstantiationTypeParameters(
 Status InstantiationBodyParameters(
     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
-    std::unordered_map<string, AttrValue>* body_parameters) {
+    absl::flat_hash_map<string, AttrValue>* body_parameters) {
   if (!body_parameters->empty()) {
     return errors::InvalidArgument("Body parameters output map must be empty");
   }
@@ -514,8 +504,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
   // Function body shares the library with the graph that instantiated it. We do
   // not need a full copy of the function library, just the reachable subset.
-  *function_body.mutable_library() =
-      ReachableFunctionLibraryDefinition(flib, func).ToProto();
+  *function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
 
   VLOG(3) << absl::Substitute(
       "Deleted $0 unreachable functions from the Grappler function item "
@@ -645,7 +634,7 @@ Status RegisterGrapplerFunctionConnectivity(
   return Status::OK();
 }
 
-Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
                              GrapplerFunctionItem* item) {
   if (!IsConstant(input_const)) {
     return errors::InvalidArgument("Input node ", input_const.name(),
@@ -657,7 +646,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
   // Find input arg expansion and input placeholder position in it for the
   // given function input position.
   InputArgExpansion* input_arg_expansion = nullptr;
-  int placeholder_idx = input_position;
+  int placeholder_idx = input_index;
 
   for (InputArgExpansion& input : inputs) {
     if (placeholder_idx < input.placeholders.size()) {
} }
if (input_arg_expansion == nullptr) { if (input_arg_expansion == nullptr) {
return errors::InvalidArgument( return errors::InvalidArgument("Input placeholder not found: input_index=",
"Input placeholder not found: input_position=", input_position, input_index, " function=", item->id);
" function=", item->id);
} }
// Delete placeholder from input expansion. // Delete placeholder from input expansion.
@@ -699,7 +687,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
   return Status::OK();
 }
 
-Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
+Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
                            GrapplerFunctionItem* item,
                            std::vector<std::pair<int, int>>* output_mapping) {
   DCHECK(output_mapping->empty());
@@ -713,7 +701,7 @@ Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
     }
   }
 
-  gtl::FlatSet<const OutputArgExpansion*> unused_output_args;
+  absl::flat_hash_set<const OutputArgExpansion*> unused_output_args;
 
   const auto is_unused_output_arg = [&](const OutputArgExpansion& output) {
     return unused_output_args.find(&output) != unused_output_args.end();

File: tensorflow/core/grappler/utils/functions.h

@@ -18,7 +18,9 @@ limitations under the License.
 #include <memory>
 #include <string>
-#include <unordered_map>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/function.pb.h"
@@ -30,13 +32,6 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
-// Returns a copy of FunctionLibraryDefinition with subset of functions that are
-// reachable from the nodes of the graph.
-FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
-    const FunctionLibraryDefinition& flib, const GraphDef& graph);
-FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
-    const FunctionLibraryDefinition& flib, const FunctionDef& func);
-
 // Depending on the function instantiation attributes, input argument to the
 // function might be a single tensor, list of tensors of the same type, or a
 // list of tensors of different types.
@@ -81,12 +76,12 @@ class GrapplerFunctionConnectivity {
   void RegisterFunctionBodyOutputs(const string& node_name,
                                    tensorflow::NameRangeMap&& outputs);
 
-  // Expand input encoded in FunctionDef format (name[:output][:position]) into
+  // Expands input encoded in FunctionDef format (name[:output][:position]) into
   // multiple inputs in GraphDef format (name[:position]).
   Status ExpandFunctionDefInput(const string& func_def_input,
                                 std::vector<string>* graph_def_inputs) const;
 
-  // Update Node inputs from FunctionDef to GraphDef format.
+  // Updates Node inputs from FunctionDef to GraphDef format.
   Status ExpandNodeInputs(NodeDef* function_body_node) const;
 
   // When expanding inputs in function def format, single input might be
@@ -96,29 +91,31 @@ class GrapplerFunctionConnectivity {
   // instantiation attributes and length of input args (and node def outputs) is
   // known.
 
-  // Map from GraphDef input format to FunctionDef input format using registered
-  // input arg expansion and function body outputs.
+  // Converts input name from GraphDef format (name[:position]) to the
+  // FunctionDef input format (name[:output][:position]) using registered input
+  // arg expansion and function body outputs.
   Status AsFunctionDefInput(const string& graph_def_input,
                             string* func_def_input) const;
 
-  // Update Node inputs from GraphDef to FunctionDef format.
+  // Updates Node inputs from GraphDef to FunctionDef format.
   Status AsFunctionDefNode(NodeDef* function_body_node) const;
 
  private:
   // Mapping from input name to input arg expansion.
-  std::unordered_map<string, InputArgExpansion> input_arg_expansions_;
+  absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
   // Mapping from function body node name to output names range map.
-  std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_;
+  absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
 
+  // For each placeholder added to the function instantiation graph, we keep a
+  // mapping back to the function input argument name and index.
   struct InputArgPlaceholder {
     string input_name;  // Name of the function input argument.
-    int input_position;  // Index of a tensor in the function input argument
+    int input_index;    // Index of a tensor in the function input argument
                         // expansion, it can be greater than `0` if input
                         // argument is a list of tensors (aka list(type)).
   };
 
   // Mapping from input arg placeholder to the function input tensor.
-  std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_;
+  absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
 };
 
 // Get Function type attributes using attributes of a node that instantiated
@@ -172,7 +169,8 @@ class GrapplerFunctionItem : public GrapplerItem {
   friend Status ReplaceInputWithConst(const NodeDef&, int,
                                       GrapplerFunctionItem*);
   friend Status RemoveUnusedOutputs(
-      const gtl::FlatSet<int>& active_outputs, GrapplerFunctionItem* item,
+      const absl::flat_hash_set<int>& active_outputs,
+      GrapplerFunctionItem* item,
       std::vector<std::pair<int, int>>* output_mapping);
 
   GrapplerFunctionItem(string func_name, string description,
@@ -191,7 +189,7 @@ class GrapplerFunctionItem : public GrapplerItem {
   std::set<string> input_arg_placeholders_;
 
-  bool is_stateful_;
+  bool is_stateful_ = false;
 };
 
 // Check if function input/output types are fully defined only at instantiation
@@ -210,14 +208,14 @@ bool IsParametrized(const FunctionDef& func);
 // caller node. Return error if type can't be resolved.
 Status InstantiationTypeParameters(
     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
-    std::unordered_map<string, DataType>* type_parameters);
+    absl::flat_hash_map<string, DataType>* type_parameters);
 
 // Resolve function instantiation body parameters (values for the function body
 // attr placeholders) from the attributes of the caller node. Return error if
 // type can't be resolved.
 Status InstantiationBodyParameters(
     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
-    std::unordered_map<string, AttrValue>* body_parameters);
+    absl::flat_hash_map<string, AttrValue>* body_parameters);
 
 // Register GrapplerFunctionItem input arg expansion and function body outputs
 // in the GrapplerFunctionConnectivity. Use function library definition to
@@ -227,7 +225,7 @@ Status RegisterGrapplerFunctionConnectivity(
     GrapplerFunctionConnectivity* connectivity);
 
 // Replace one of the function inputs with a constant.
-Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
+Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
                              GrapplerFunctionItem* item);
 
 // Remove function output arguments that do not have any active outputs (output
@@ -236,7 +234,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
 // potentially be connected to the same output argument (in case of tensor list
 // outputs). Add output mapping for all active outputs that changed it's output
 // position (std::pair<old position, new position>).
-Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
+Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
                            GrapplerFunctionItem* item,
                            std::vector<std::pair<int, int>>* output_mapping);

File: tensorflow/core/grappler/utils/functions_test.cc

@@ -14,6 +14,8 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/utils/functions.h"
+
+#include "absl/container/flat_hash_map.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/node_def.pb.h"
@@ -77,7 +79,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
   func_instantiation_attr["B"].set_type(DT_INT32);
   func_instantiation_attr["C"].set_type(DT_DOUBLE);
 
-  std::unordered_map<string, DataType> type_parameters;
+  absl::flat_hash_map<string, DataType> type_parameters;
   TF_EXPECT_OK(InstantiationTypeParameters(
       func, AttrSlice(&func_instantiation_attr), &type_parameters));
@@ -86,7 +88,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
   EXPECT_EQ(DT_INT32, type_parameters["B"]);
   EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
 
-  std::unordered_map<string, AttrValue> body_parameters;
+  absl::flat_hash_map<string, AttrValue> body_parameters;
   TF_EXPECT_OK(InstantiationBodyParameters(
       func, AttrSlice(&func_instantiation_attr), &body_parameters));