[Grappler] Cleanup function/function_optimizer

1. Migrate to absl containers
2. Compute function signature hash without allocating temporary sorted containers (sketched below)

PiperOrigin-RevId: 227587343
Eugene Zhulenev 2019-01-02 15:00:14 -08:00 committed by TensorFlower Gardener
parent 790a635a04
commit a7a3bbfcbf
7 changed files with 149 additions and 154 deletions
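
Note on point 2: the old Hash functor below copies every unordered container into a std::set/std::map just to get a deterministic iteration order. The new AbslHashValue overload instead hashes each element in whatever order the container yields, sorts the resulting hash values, and folds them into the hash state. A minimal standalone sketch of that pattern, assuming an illustrative Signature type that is not part of this commit:

#include <algorithm>
#include <string>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"

// Illustrative type: one ordered field plus a set field whose iteration
// order is non-deterministic.
struct Signature {
  std::string name;
  absl::flat_hash_set<std::string> tags;

  bool operator==(const Signature& other) const {
    return name == other.name && tags == other.tags;
  }

  template <typename H>
  friend H AbslHashValue(H h, const Signature& s) {
    H state = H::combine(std::move(h), s.name);
    // Pre-compute per-element hashes instead of allocating a sorted
    // temporary container of the elements themselves.
    std::vector<size_t> hashes;
    hashes.reserve(s.tags.size());
    for (const std::string& tag : s.tags) {
      hashes.push_back(absl::Hash<std::string>()(tag));
    }
    // Sorting the hashes makes the result independent of iteration order.
    std::sort(hashes.begin(), hashes.end());
    return H::combine_contiguous(std::move(state), hashes.data(),
                                 hashes.size());
  }
};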


@@ -151,6 +151,8 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:functions",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],


@@ -15,10 +15,11 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include <unordered_map>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/substitute.h"
@@ -163,10 +164,10 @@ struct FunctionSpecializationSignature {
string func_name;
bool is_in_fetch_set;
gtl::FlatSet<OutputPort> active_outputs;
std::unordered_map<string, DataType> type_parameters;
std::unordered_map<string, AttrValue> body_parameters;
std::unordered_map<InputPort, string> const_inputs;
absl::flat_hash_set<OutputPort> active_outputs;
absl::flat_hash_map<string, DataType> type_parameters;
absl::flat_hash_map<string, AttrValue> body_parameters;
absl::flat_hash_map<InputPort, string> const_inputs;
bool operator==(const FunctionSpecializationSignature& other) const {
bool equals = func_name == other.func_name &&
@@ -189,62 +190,59 @@ struct FunctionSpecializationSignature {
return true;
}
// TODO(ezhulenev): Migrate to AbslHashValue.
// TODO(ezhulenev): Optimize performance by computing hashes of unordered
// values first, and then compute a hash of sorted hashes.
struct Hash {
uint64 operator()(FunctionSpecializationSignature const& s) const {
uint64 h = Hash64(s.func_name);
h = Hash64Combine(std::hash<bool>()(s.is_in_fetch_set), h);
template <typename H>
friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
// Use std::set/std::map for deterministic iteration order.
// First pre-compute hashes for all values in collections with
// non-deterministic iteration order.
std::vector<uint64> hashes;
hashes.reserve(s.active_outputs.size() //
+ s.type_parameters.size() * 2 //
+ s.body_parameters.size() * 2 //
+ s.const_inputs.size() * 2);
std::set<OutputPort> active_outputs(s.active_outputs.begin(),
s.active_outputs.end());
for (const auto& active_output : active_outputs) {
h = Hash64Combine(std::hash<int>()(active_output), h);
}
absl::c_transform(s.active_outputs, std::back_inserter(hashes),
hash<OutputPort>());
std::map<string, DataType> types(s.type_parameters.begin(),
s.type_parameters.end());
for (const auto& pair : types) {
using TypeParam = std::pair<const string, DataType>;
absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
AttrValue attr_value;
attr_value.set_type(pair.second);
h = Hash64Combine(Hash64(pair.first), h);
h = Hash64Combine(AttrValueHash(attr_value), h);
}
attr_value.set_type(type_param.second);
hashes.push_back(Hash64(type_param.first));
hashes.push_back(AttrValueHash(attr_value));
});
std::map<string, AttrValue> body(s.body_parameters.begin(),
s.body_parameters.end());
for (const auto& pair : body) {
h = Hash64Combine(Hash64(pair.first), h);
h = Hash64Combine(FastAttrValueHash(pair.second), h);
}
using BodyParam = std::pair<const string, AttrValue>;
absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
hashes.push_back(Hash64(body_param.first));
hashes.push_back(FastAttrValueHash(body_param.second));
});
std::map<InputPort, string> inputs(s.const_inputs.begin(),
s.const_inputs.end());
for (const auto& pair : inputs) {
h = Hash64Combine(std::hash<int>()(pair.first), h);
h = Hash64Combine(Hash64(pair.second), h);
}
using ConstInput = std::pair<const InputPort, string>;
absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
hashes.push_back(hash<InputPort>()(const_input.first));
hashes.push_back(Hash64(const_input.second));
});
return h;
// Combine all pre-computed hashes in a deterministic order.
absl::c_sort(hashes);
return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
}
};
};
struct FunctionSpecialization {
string specialized_func_name;
// True if the function caller node is in GrapplerItem fetch set.
bool is_in_fetch_set;
// Names of the tensors that were pushed down into the function body.
gtl::FlatSet<string> const_inputs;
absl::flat_hash_set<string> const_inputs;
// Control dependencies of pushed down const inputs have to be attached to
// function caller node.
gtl::FlatSet<string> control_deps;
absl::flat_hash_set<string> control_deps;
// Output tensors (ports) that are consumed by other nodes in the graph or in a
// GrapplerItem fetch set.
gtl::FlatSet<int> active_outputs;
absl::flat_hash_set<int> active_outputs;
// Mapping from original function output port to the output port of
// specialized function. If function specialization changes the number of
// function outputs it's required to update all node consumers.
@@ -285,12 +283,13 @@ class FunctionOptimizerContext {
return flr_;
}
const gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const {
return tensor_mapping_;
}
const gtl::FlatMap<string, std::vector<string>>& control_overrides() const {
const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
const {
return control_overrides_;
}
@@ -298,7 +297,9 @@ class FunctionOptimizerContext {
const string& grappler_item_id() const { return grappler_item_id_; }
const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
const absl::flat_hash_set<string>& fetch_tensors() const {
return fetch_tensors_;
}
const DeviceSet* devices() const {
// Create fake devices lazily only if we need a DeviceSet.
@@ -365,7 +366,7 @@ class FunctionOptimizerContext {
private:
void InitializeTrulyConstNodes(const GrapplerItem& item) {
gtl::FlatSet<string> feed_nodes;
absl::flat_hash_set<string> feed_nodes;
for (const auto& feed : item.feed) {
feed_nodes.insert(NodeName(feed.first));
}
@@ -411,7 +412,7 @@ class FunctionOptimizerContext {
FunctionLibraryRuntime* flr_ = nullptr;
// Fully defined names of the devices available to the GrapplerItem.
const gtl::FlatSet<string> available_device_names_;
const absl::flat_hash_set<string> available_device_names_;
// List of available `FakedDevices` (lazily initialized, see devices()).
mutable std::vector<std::unique_ptr<Device>> available_devices_;
@@ -421,16 +422,15 @@ class FunctionOptimizerContext {
mutable DeviceSet available_device_set_;
// Nodes that are Const and not in feed.
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
// Specialized functions.
std::unordered_map<FunctionSpecializationSignature,
const FunctionSpecialization,
FunctionSpecializationSignature::Hash>
absl::flat_hash_map<FunctionSpecializationSignature,
const FunctionSpecialization>
specialized_functions_;
// GrapplerItem.fetch is a vector of tensors.
gtl::FlatSet<string> fetch_tensors_; // format: node_name:port
gtl::FlatSet<string> fetch_nodes_; // format: node_name
absl::flat_hash_set<string> fetch_tensors_; // format: node_name:port
absl::flat_hash_set<string> fetch_nodes_; // format: node_name
// After function inlining and specialization, the optimized graph might be in
// an invalid state: nodes can read from non-existing function call nodes that
@@ -439,7 +439,7 @@ class FunctionOptimizerContext {
//
// Tensor mapping that has to be applied to the graph after all function
// optimizations (invalidated tensor id -> optimized graph tensor id).
gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
tensor_mapping_;
// When we inline a function into the optimized graph, we no longer have the
@@ -448,7 +448,7 @@ class FunctionOptimizerContext {
// to all side-effectful ops inside the function body.
//
// Invalidated function call node name -> Inlined side-effectful nodes
gtl::FlatMap<string, std::vector<string>> control_overrides_;
absl::flat_hash_map<string, std::vector<string>> control_overrides_;
// Use graph view to find active outputs of the function caller nodes.
GraphView graph_view_;
@@ -472,10 +472,10 @@ const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
return ctx.function_library().Find(node.op());
}
gtl::FlatSet<int> GetActiveOutputs(const NodeDef& node,
absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
const FunctionOptimizerContext& ctx,
int size_hint = 0) {
gtl::FlatSet<int> active_outputs;
absl::flat_hash_set<int> active_outputs;
active_outputs.reserve(static_cast<size_t>(size_hint));
// 1. Output can be consumed by another graph node.
@@ -508,7 +508,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
// number of output args is the same as number of possible function caller
// node outputs.
int num_outputs = func.signature().output_arg_size();
const gtl::FlatSet<int> active_outputs =
const absl::flat_hash_set<int> active_outputs =
GetActiveOutputs(func_node, ctx, /*size_hint=*/num_outputs);
return active_outputs.size() != num_outputs;
@@ -519,7 +519,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
const GraphDef& optimized_graph) {
FunctionLibraryDefinition pruned_flib =
ReachableFunctionLibraryDefinition(flib, optimized_graph);
flib.ReachableDefinitions(optimized_graph);
int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
static_cast<int>(flib.num_functions());
@@ -534,8 +534,8 @@ FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
Status PushDownConstInputs(const NodeDef& func_node,
const FunctionOptimizerContext& ctx,
GrapplerFunctionItem* item,
gtl::FlatSet<string>* const_inputs,
gtl::FlatSet<string>* control_deps) {
absl::flat_hash_set<string>* const_inputs,
absl::flat_hash_set<string>* control_deps) {
// Record node control dependencies in the control_deps set.
const auto record_control_deps = [&](const NodeDef* const_input) {
for (int i = const_input->input_size() - 1; i >= 0; --i) {
@@ -585,7 +585,7 @@ void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
// Attach control dependencies of pushed down const input to the caller node.
if (!specialization.control_deps.empty()) {
gtl::FlatSet<string> existing_control_deps;
absl::flat_hash_set<string> existing_control_deps;
for (const string& input : keep_inputs) {
existing_control_deps.insert(AsControlDependency(NodeName(input)));
@@ -797,8 +797,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Push const inputs into the function body, and keep track of their control
// dependencies.
gtl::FlatSet<string> const_inputs;
gtl::FlatSet<string> control_deps;
absl::flat_hash_set<string> const_inputs;
absl::flat_hash_set<string> control_deps;
TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
&control_deps));
@@ -1005,7 +1005,7 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
// Mapping from input placeholder name to function input position.
int idx = 0;
std::unordered_map<string, int> input_placeholders_idx;
absl::flat_hash_map<string, int> input_placeholders_idx;
for (const InputArgExpansion& input_arg : item.inputs()) {
for (const string& placeholder : input_arg.placeholders) {
input_placeholders_idx[placeholder] = idx++;
@@ -1699,7 +1699,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
if (!ctx.control_overrides().empty()) {
for (NodeDef& node : *optimized_graph->mutable_node()) {
// Keep track of new control inputs to the node.
gtl::FlatSet<string> add_ctrl_inputs;
absl::flat_hash_set<string> add_ctrl_inputs;
// Remove all invalidated control inputs.
for (int idx = 0; idx < node.input_size(); /* see below */) {
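
A side effect of defining AbslHashValue, visible in the specialized_functions_ declaration above: absl::flat_hash_map defaults to absl::Hash<Key>, so the explicit FunctionSpecializationSignature::Hash template argument can be dropped. A sketch reusing the illustrative Signature type from the earlier note:

#include <string>

#include "absl/container/flat_hash_map.h"

// The default absl::Hash<Signature> dispatches to the friend AbslHashValue
// defined in the earlier sketch; no hash functor argument is needed.
void SpecializationMapExample() {
  absl::flat_hash_map<Signature, std::string> specialized_functions;
  specialized_functions[Signature{"my_func", {"T=float"}}] =
      "my_func_specialized";
}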


@@ -427,6 +427,14 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// Constructs a FunctionLibraryDefinition with functions that are reachable
// from the nodes of the graph.
const auto minimized_flib =
[](const GraphDef& graph) -> FunctionLibraryDefinition {
return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
.ReachableDefinitions(graph);
};
// 0. The original graph might contain a huge function library that is mostly
// unused. This library is copied over by each individual Grappler optimizer,
// which adds a huge overhead. Before starting optimization passes we just
@@ -436,11 +444,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef trimmed_graph; // do not copy graph with a potentially huge library
*trimmed_graph.mutable_node() = item.graph.node();
*trimmed_graph.mutable_versions() = item.graph.versions();
*trimmed_graph.mutable_library() =
grappler::ReachableFunctionLibraryDefinition(
FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()),
item.graph)
.ToProto();
*trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();
GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
@@ -472,10 +476,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
// 2. Optimize functions reachable from the optimized graph.
FunctionLibraryDefinition flib = ReachableFunctionLibraryDefinition(
FunctionLibraryDefinition(OpRegistry::Global(),
optimized_graph->library()),
*optimized_graph);
FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
// Find functions for which we might need to compute a gradient at runtime.
absl::flat_hash_set<string> differentiable_functions;
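
The change above also factors the two reachable-library call sites into a single minimized_flib lambda. A sketch of the trimming pattern it implements, using only the TensorFlow APIs that appear in this diff (the TrimFunctionLibrary helper name is ours, not the commit's):

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Copies nodes and versions but rebuilds the function library with only the
// functions reachable from the graph's nodes, so the (potentially huge)
// unused remainder of the library is never copied.
GraphDef TrimFunctionLibrary(const GraphDef& graph) {
  GraphDef trimmed;
  *trimmed.mutable_node() = graph.node();
  *trimmed.mutable_versions() = graph.versions();
  *trimmed.mutable_library() =
      FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
          .ReachableDefinitions(graph)
          .ToProto();
  return trimmed;
}

}  // namespace tensorflow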


@@ -178,6 +178,9 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -196,6 +199,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@com_google_absl//absl/container:flat_hash_map",
],
)


@@ -14,8 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/functions.h"
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
@@ -76,16 +76,6 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
const FunctionLibraryDefinition& flib, const GraphDef& graph) {
return flib.ReachableDefinitions(graph);
}
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
const FunctionLibraryDefinition& flib, const FunctionDef& func) {
return flib.ReachableDefinitions(func);
}
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
InputArgExpansion input_arg_expansion) {
string input_name = input_arg_expansion.input_name;
@@ -94,7 +84,7 @@ void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
input_arg_placeholders_.insert(
{placeholder, InputArgPlaceholder{input_name, /*input_position=*/i}});
{placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
}
input_arg_expansions_.insert(
{std::move(input_name), std::move(input_arg_expansion)});
@@ -193,7 +183,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
// If position is not defined, expand the node output range
for (int i = output_range.first; i < output_range.second; ++i) {
graph_def_inputs->push_back(
i == 0 ? node_name : strings::StrCat(node_name, ":", i));
i == 0 ? node_name : absl::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -203,7 +193,7 @@
}
int pos = output_range.first + position;
graph_def_inputs->push_back(
pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -232,39 +222,39 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
const string& graph_def_input, string* func_def_input) const {
using gtl::FindOrNull;
if (IsControlInput(graph_def_input)) {
*func_def_input = graph_def_input;
return Status::OK();
}
int position;
string node_name = ParseNodeName(graph_def_input, &position);
CHECK_GE(position, 0);
const TensorId tensor = ParseTensorName(graph_def_input);
DCHECK_GE(tensor.index(), 0);
const absl::string_view node_name = tensor.node();
const int index = tensor.index();
// Check if it's an input arg placeholder
if (position == 0) {
const InputArgPlaceholder* placeholder =
FindOrNull(input_arg_placeholders_, node_name);
if (placeholder != nullptr) {
*func_def_input = strings::StrCat(placeholder->input_name, ":",
placeholder->input_position);
if (tensor.index() == 0) {
const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
if (is_input_placeholder != input_arg_placeholders_.end()) {
const InputArgPlaceholder& placeholder = is_input_placeholder->second;
*func_def_input =
absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
return Status::OK();
}
}
// It must be output from one of the function body nodes
const tensorflow::NameRangeMap* outputs_range_map =
FindOrNull(function_body_outputs_, node_name);
if (outputs_range_map != nullptr) {
for (const auto& el : *outputs_range_map) {
const auto is_body_output = function_body_outputs_.find(tensor.node());
if (is_body_output != function_body_outputs_.end()) {
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
for (const auto& el : outputs_range_map) {
const auto& output_name = el.first;
const auto& output_range = el.second;
if (position >= output_range.first && position < output_range.second) {
int pos = position - output_range.first;
*func_def_input =
strings::StrCat(node_name, ":", output_name, ":", pos);
if (index >= output_range.first && index < output_range.second) {
int pos = index - output_range.first;
*func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
return Status::OK();
}
}
@@ -426,7 +416,7 @@ bool IsParametrized(const FunctionDef& func) {
Status InstantiationTypeParameters(
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
std::unordered_map<string, DataType>* type_parameters) {
absl::flat_hash_map<string, DataType>* type_parameters) {
if (!type_parameters->empty()) {
return errors::InvalidArgument("Type parameters output map must be empty");
}
@@ -454,7 +444,7 @@ Status InstantiationTypeParameters(
Status InstantiationBodyParameters(
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
std::unordered_map<string, AttrValue>* body_parameters) {
absl::flat_hash_map<string, AttrValue>* body_parameters) {
if (!body_parameters->empty()) {
return errors::InvalidArgument("Body parameters output map must be empty");
}
@@ -514,8 +504,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// Function body shares the library with the graph that instantiated it. We do
// not need a full copy of the function library, just the reachable subset.
*function_body.mutable_library() =
ReachableFunctionLibraryDefinition(flib, func).ToProto();
*function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
VLOG(3) << absl::Substitute(
"Deleted $0 unreachable functions from the Grappler function item "
@@ -645,7 +634,7 @@ Status RegisterGrapplerFunctionConnectivity(
return Status::OK();
}
Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
GrapplerFunctionItem* item) {
if (!IsConstant(input_const)) {
return errors::InvalidArgument("Input node ", input_const.name(),
@@ -657,7 +646,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
// Find input arg expansion and input placeholder position in it for the
// given function input index.
InputArgExpansion* input_arg_expansion = nullptr;
int placeholder_idx = input_position;
int placeholder_idx = input_index;
for (InputArgExpansion& input : inputs) {
if (placeholder_idx < input.placeholders.size()) {
@@ -668,9 +657,8 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
}
if (input_arg_expansion == nullptr) {
return errors::InvalidArgument(
"Input placeholder not found: input_position=", input_position,
" function=", item->id);
return errors::InvalidArgument("Input placeholder not found: input_index=",
input_index, " function=", item->id);
}
// Delete placeholder from input expansion.
@@ -699,7 +687,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
return Status::OK();
}
Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
GrapplerFunctionItem* item,
std::vector<std::pair<int, int>>* output_mapping) {
DCHECK(output_mapping->empty());
@@ -713,7 +701,7 @@ Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
}
}
gtl::FlatSet<const OutputArgExpansion*> unused_output_args;
absl::flat_hash_set<const OutputArgExpansion*> unused_output_args;
const auto is_unused_output_arg = [&](const OutputArgExpansion& output) {
return unused_output_args.find(&output) != unused_output_args.end();
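
The AsFunctionDefInput rewrite above relies on two details worth noting: ParseTensorName splits "node:2" into a TensorId that views into the original string, and absl::flat_hash_map with std::string keys supports heterogeneous lookup, so the string_view returned by tensor.node() can be passed to find() without materializing a temporary string. A small sketch under those assumptions (the map and helper here are illustrative, not from this file):

#include <string>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/tensor_id.h"

namespace tensorflow {

// Parses "node:2" into ("node", 2) without copying, then looks the node up
// by string_view; no temporary std::string is allocated for the key.
int FindNodePosition(
    const absl::flat_hash_map<std::string, int>& node_positions,
    const std::string& graph_def_input) {
  const TensorId tensor = ParseTensorName(graph_def_input);
  const auto it = node_positions.find(tensor.node());
  return it == node_positions.end() ? -1 : it->second;
}

}  // namespace tensorflow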


@@ -18,7 +18,9 @@ limitations under the License.
#include <memory>
#include <string>
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -30,13 +32,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Returns a copy of FunctionLibraryDefinition with subset of functions that are
// reachable from the nodes of the graph.
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
const FunctionLibraryDefinition& flib, const GraphDef& graph);
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
const FunctionLibraryDefinition& flib, const FunctionDef& func);
// Depending on the function instantiation attributes, an input argument to the
// function might be a single tensor, a list of tensors of the same type, or a
// list of tensors of different types.
@@ -81,12 +76,12 @@ class GrapplerFunctionConnectivity {
void RegisterFunctionBodyOutputs(const string& node_name,
tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// Expands input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
Status ExpandFunctionDefInput(const string& func_def_input,
std::vector<string>* graph_def_inputs) const;
// Update Node inputs from FunctionDef to GraphDef format.
// Updates Node inputs from FunctionDef to GraphDef format.
Status ExpandNodeInputs(NodeDef* function_body_node) const;
// When expanding inputs in function def format, a single input might be
@@ -96,29 +91,31 @@ class GrapplerFunctionConnectivity {
// instantiation attributes and length of input args (and node def outputs) is
// known.
// Map from GraphDef input format to FunctionDef input format using registered
// input arg expansion and function body outputs.
// Converts input name from GraphDef format (name[:position]) to the
// FunctionDef input format (name[:output][:position]) using registered input
// arg expansion and function body outputs.
Status AsFunctionDefInput(const string& graph_def_input,
string* func_def_input) const;
// Update Node inputs from GraphDef to FunctionDef format.
// Updates Node inputs from GraphDef to FunctionDef format.
Status AsFunctionDefNode(NodeDef* function_body_node) const;
private:
// Mapping from input name to input arg expansion.
std::unordered_map<string, InputArgExpansion> input_arg_expansions_;
absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
// Mapping from function body node name to output names range map.
std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_;
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
// For each placeholder added to the function instantiation graph, we keep a
// mapping back to the function input argument name and index.
struct InputArgPlaceholder {
string input_name; // Name of the function input argument.
int input_position; // Index of a tensor in the function input argument
int input_index; // Index of a tensor in the function input argument
// expansion, it can be greater than `0` if input
// argument is a list of tensors (aka list(type)).
};
// Mapping from input arg placeholder to the function input tensor.
std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_;
absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
};
// Get Function type attributes using attributes of a node that instantiated
@@ -172,7 +169,8 @@ class GrapplerFunctionItem : public GrapplerItem {
friend Status ReplaceInputWithConst(const NodeDef&, int,
GrapplerFunctionItem*);
friend Status RemoveUnusedOutputs(
const gtl::FlatSet<int>& active_outputs, GrapplerFunctionItem* item,
const absl::flat_hash_set<int>& active_outputs,
GrapplerFunctionItem* item,
std::vector<std::pair<int, int>>* output_mapping);
GrapplerFunctionItem(string func_name, string description,
@@ -191,7 +189,7 @@ class GrapplerFunctionItem : public GrapplerItem {
std::set<string> input_arg_placeholders_;
bool is_stateful_;
bool is_stateful_ = false;
};
// Check if function input/output types are fully defined only at instantiation
@@ -210,14 +208,14 @@ bool IsParametrized(const FunctionDef& func);
// caller node. Return error if type can't be resolved.
Status InstantiationTypeParameters(
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
std::unordered_map<string, DataType>* type_parameters);
absl::flat_hash_map<string, DataType>* type_parameters);
// Resolve function instantiation body parameters (values for the function body
// attr placeholders) from the attributes of the caller node. Return error if
// type can't be resolved.
Status InstantiationBodyParameters(
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
std::unordered_map<string, AttrValue>* body_parameters);
absl::flat_hash_map<string, AttrValue>* body_parameters);
// Register GrapplerFunctionItem input arg expansion and function body outputs
// in the GrapplerFunctionConnectivity. Use function library definition to
@@ -227,7 +225,7 @@ Status RegisterGrapplerFunctionConnectivity(
GrapplerFunctionConnectivity* connectivity);
// Replace one of the function inputs with a constant.
Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
GrapplerFunctionItem* item);
// Remove function output arguments that do not have any active outputs (output
@@ -236,7 +234,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
// potentially be connected to the same output argument (in case of tensor list
// outputs). Add output mapping for all active outputs that changed their output
// position (std::pair<old position, new position>).
Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
GrapplerFunctionItem* item,
std::vector<std::pair<int, int>>* output_mapping);


@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/functions.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -77,7 +79,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
func_instantiation_attr["B"].set_type(DT_INT32);
func_instantiation_attr["C"].set_type(DT_DOUBLE);
std::unordered_map<string, DataType> type_parameters;
absl::flat_hash_map<string, DataType> type_parameters;
TF_EXPECT_OK(InstantiationTypeParameters(
func, AttrSlice(&func_instantiation_attr), &type_parameters));
@@ -86,7 +88,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
EXPECT_EQ(DT_INT32, type_parameters["B"]);
EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
std::unordered_map<string, AttrValue> body_parameters;
absl::flat_hash_map<string, AttrValue> body_parameters;
TF_EXPECT_OK(InstantiationBodyParameters(
func, AttrSlice(&func_instantiation_attr), &body_parameters));