[Grappler] Cleanup function/function_optimizer
1. Migrate to absl containers 2. Compute function signature hash without allocating temporary sorted containers PiperOrigin-RevId: 227587343
This commit is contained in:
parent
790a635a04
commit
a7a3bbfcbf
@ -151,6 +151,8 @@ cc_library(
|
|||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/utils:functions",
|
"//tensorflow/core/grappler/utils:functions",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
|
@ -15,10 +15,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/algorithm/container.h"
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
#include "absl/strings/str_replace.h"
|
#include "absl/strings/str_replace.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
@ -163,10 +164,10 @@ struct FunctionSpecializationSignature {
|
|||||||
|
|
||||||
string func_name;
|
string func_name;
|
||||||
bool is_in_fetch_set;
|
bool is_in_fetch_set;
|
||||||
gtl::FlatSet<OutputPort> active_outputs;
|
absl::flat_hash_set<OutputPort> active_outputs;
|
||||||
std::unordered_map<string, DataType> type_parameters;
|
absl::flat_hash_map<string, DataType> type_parameters;
|
||||||
std::unordered_map<string, AttrValue> body_parameters;
|
absl::flat_hash_map<string, AttrValue> body_parameters;
|
||||||
std::unordered_map<InputPort, string> const_inputs;
|
absl::flat_hash_map<InputPort, string> const_inputs;
|
||||||
|
|
||||||
bool operator==(const FunctionSpecializationSignature& other) const {
|
bool operator==(const FunctionSpecializationSignature& other) const {
|
||||||
bool equals = func_name == other.func_name &&
|
bool equals = func_name == other.func_name &&
|
||||||
@ -189,48 +190,45 @@ struct FunctionSpecializationSignature {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(ezhulenev): Migrate to AbslHashValue.
|
template <typename H>
|
||||||
// TODO(ezhulenev): Optimize performance by computing hashes of unordered
|
friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
|
||||||
// values first, and then compute a hash of sorted hashes.
|
H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
|
||||||
struct Hash {
|
|
||||||
uint64 operator()(FunctionSpecializationSignature const& s) const {
|
|
||||||
uint64 h = Hash64(s.func_name);
|
|
||||||
h = Hash64Combine(std::hash<bool>()(s.is_in_fetch_set), h);
|
|
||||||
|
|
||||||
// Use std::set/std::map for deterministic iteration order.
|
// First pre-compute hashes for all values in collections with
|
||||||
|
// non-deterministic iteration order.
|
||||||
|
std::vector<uint64> hashes;
|
||||||
|
hashes.reserve(s.active_outputs.size() //
|
||||||
|
+ s.type_parameters.size() * 2 //
|
||||||
|
+ s.body_parameters.size() * 2 //
|
||||||
|
+ s.const_inputs.size() * 2);
|
||||||
|
|
||||||
std::set<OutputPort> active_outputs(s.active_outputs.begin(),
|
absl::c_transform(s.active_outputs, std::back_inserter(hashes),
|
||||||
s.active_outputs.end());
|
hash<OutputPort>());
|
||||||
for (const auto& active_output : active_outputs) {
|
|
||||||
h = Hash64Combine(std::hash<int>()(active_output), h);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<string, DataType> types(s.type_parameters.begin(),
|
using TypeParam = std::pair<const string, DataType>;
|
||||||
s.type_parameters.end());
|
absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
|
||||||
for (const auto& pair : types) {
|
AttrValue attr_value;
|
||||||
AttrValue attr_value;
|
attr_value.set_type(type_param.second);
|
||||||
attr_value.set_type(pair.second);
|
hashes.push_back(Hash64(type_param.first));
|
||||||
h = Hash64Combine(Hash64(pair.first), h);
|
hashes.push_back(AttrValueHash(attr_value));
|
||||||
h = Hash64Combine(AttrValueHash(attr_value), h);
|
});
|
||||||
}
|
|
||||||
|
|
||||||
std::map<string, AttrValue> body(s.body_parameters.begin(),
|
using BodyParam = std::pair<const string, AttrValue>;
|
||||||
s.body_parameters.end());
|
absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
|
||||||
for (const auto& pair : body) {
|
hashes.push_back(Hash64(body_param.first));
|
||||||
h = Hash64Combine(Hash64(pair.first), h);
|
hashes.push_back(FastAttrValueHash(body_param.second));
|
||||||
h = Hash64Combine(FastAttrValueHash(pair.second), h);
|
});
|
||||||
}
|
|
||||||
|
|
||||||
std::map<InputPort, string> inputs(s.const_inputs.begin(),
|
using ConstInput = std::pair<const InputPort, string>;
|
||||||
s.const_inputs.end());
|
absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
|
||||||
for (const auto& pair : inputs) {
|
hashes.push_back(hash<InputPort>()(const_input.first));
|
||||||
h = Hash64Combine(std::hash<int>()(pair.first), h);
|
hashes.push_back(Hash64(const_input.second));
|
||||||
h = Hash64Combine(Hash64(pair.second), h);
|
});
|
||||||
}
|
|
||||||
|
|
||||||
return h;
|
// Combine all pre-computed hashes in a deterministic order.
|
||||||
}
|
absl::c_sort(hashes);
|
||||||
};
|
return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FunctionSpecialization {
|
struct FunctionSpecialization {
|
||||||
@ -238,13 +236,13 @@ struct FunctionSpecialization {
|
|||||||
// True if the function caller node is in GrapplerItem fetch set.
|
// True if the function caller node is in GrapplerItem fetch set.
|
||||||
bool is_in_fetch_set;
|
bool is_in_fetch_set;
|
||||||
// Names of the tensors that were pushed down into the function body.
|
// Names of the tensors that were pushed down into the function body.
|
||||||
gtl::FlatSet<string> const_inputs;
|
absl::flat_hash_set<string> const_inputs;
|
||||||
// Control dependencies of pushed down const inputs have to be attached to
|
// Control dependencies of pushed down const inputs have to be attached to
|
||||||
// function caller node.
|
// function caller node.
|
||||||
gtl::FlatSet<string> control_deps;
|
absl::flat_hash_set<string> control_deps;
|
||||||
// Output tensors (ports) that consumed by other nodes in the graph or in a
|
// Output tensors (ports) that consumed by other nodes in the graph or in a
|
||||||
// GrapplerItem fetch set.
|
// GrapplerItem fetch set.
|
||||||
gtl::FlatSet<int> active_outputs;
|
absl::flat_hash_set<int> active_outputs;
|
||||||
// Mapping from original function output port to the output port of
|
// Mapping from original function output port to the output port of
|
||||||
// specialized function. If function specialization changes the number of
|
// specialized function. If function specialization changes the number of
|
||||||
// function outputs it's required to update all node consumers.
|
// function outputs it's required to update all node consumers.
|
||||||
@ -285,12 +283,13 @@ class FunctionOptimizerContext {
|
|||||||
return flr_;
|
return flr_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
|
const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
|
||||||
tensor_mapping() const {
|
tensor_mapping() const {
|
||||||
return tensor_mapping_;
|
return tensor_mapping_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const gtl::FlatMap<string, std::vector<string>>& control_overrides() const {
|
const absl::flat_hash_map<string, std::vector<string>>& control_overrides()
|
||||||
|
const {
|
||||||
return control_overrides_;
|
return control_overrides_;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -298,7 +297,9 @@ class FunctionOptimizerContext {
|
|||||||
|
|
||||||
const string& grappler_item_id() const { return grappler_item_id_; }
|
const string& grappler_item_id() const { return grappler_item_id_; }
|
||||||
|
|
||||||
const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
|
const absl::flat_hash_set<string>& fetch_tensors() const {
|
||||||
|
return fetch_tensors_;
|
||||||
|
}
|
||||||
|
|
||||||
const DeviceSet* devices() const {
|
const DeviceSet* devices() const {
|
||||||
// Create fake devices lazily only if we need a DeviceSet.
|
// Create fake devices lazily only if we need a DeviceSet.
|
||||||
@ -365,7 +366,7 @@ class FunctionOptimizerContext {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
void InitializeTrulyConstNodes(const GrapplerItem& item) {
|
void InitializeTrulyConstNodes(const GrapplerItem& item) {
|
||||||
gtl::FlatSet<string> feed_nodes;
|
absl::flat_hash_set<string> feed_nodes;
|
||||||
for (const auto& feed : item.feed) {
|
for (const auto& feed : item.feed) {
|
||||||
feed_nodes.insert(NodeName(feed.first));
|
feed_nodes.insert(NodeName(feed.first));
|
||||||
}
|
}
|
||||||
@ -411,7 +412,7 @@ class FunctionOptimizerContext {
|
|||||||
FunctionLibraryRuntime* flr_ = nullptr;
|
FunctionLibraryRuntime* flr_ = nullptr;
|
||||||
|
|
||||||
// Fully defined names of the devices available to the GrapplerItem.
|
// Fully defined names of the devices available to the GrapplerItem.
|
||||||
const gtl::FlatSet<string> available_device_names_;
|
const absl::flat_hash_set<string> available_device_names_;
|
||||||
|
|
||||||
// List of available `FakedDevices` (lazily initialized, see devices()).
|
// List of available `FakedDevices` (lazily initialized, see devices()).
|
||||||
mutable std::vector<std::unique_ptr<Device>> available_devices_;
|
mutable std::vector<std::unique_ptr<Device>> available_devices_;
|
||||||
@ -421,16 +422,15 @@ class FunctionOptimizerContext {
|
|||||||
mutable DeviceSet available_device_set_;
|
mutable DeviceSet available_device_set_;
|
||||||
|
|
||||||
// Nodes that are Const and not in feed.
|
// Nodes that are Const and not in feed.
|
||||||
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
|
absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
|
||||||
// Specialized functions.
|
// Specialized functions.
|
||||||
std::unordered_map<FunctionSpecializationSignature,
|
absl::flat_hash_map<FunctionSpecializationSignature,
|
||||||
const FunctionSpecialization,
|
const FunctionSpecialization>
|
||||||
FunctionSpecializationSignature::Hash>
|
|
||||||
specialized_functions_;
|
specialized_functions_;
|
||||||
|
|
||||||
// GrapplerItem.fetch is a vector of tensors.
|
// GrapplerItem.fetch is a vector of tensors.
|
||||||
gtl::FlatSet<string> fetch_tensors_; // format: node_name:port
|
absl::flat_hash_set<string> fetch_tensors_; // format: node_name:port
|
||||||
gtl::FlatSet<string> fetch_nodes_; // format: node_name
|
absl::flat_hash_set<string> fetch_nodes_; // format: node_name
|
||||||
|
|
||||||
// After function inlining and specialization, the optimized graph might be in
|
// After function inlining and specialization, the optimized graph might be in
|
||||||
// invalid state, nodes can read from non-existing function call nodes that
|
// invalid state, nodes can read from non-existing function call nodes that
|
||||||
@ -439,7 +439,7 @@ class FunctionOptimizerContext {
|
|||||||
//
|
//
|
||||||
// Tensor mapping that has to be applied to the graph after all functions
|
// Tensor mapping that has to be applied to the graph after all functions
|
||||||
// optimizations (invalidated tensor id -> optimized graph tensor id).
|
// optimizations (invalidated tensor id -> optimized graph tensor id).
|
||||||
gtl::FlatMap<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
|
absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
|
||||||
tensor_mapping_;
|
tensor_mapping_;
|
||||||
|
|
||||||
// When we inline a function into the optimized graph, we no longer have the
|
// When we inline a function into the optimized graph, we no longer have the
|
||||||
@ -448,7 +448,7 @@ class FunctionOptimizerContext {
|
|||||||
// to all side-effectful ops inside the function body.
|
// to all side-effectful ops inside the function body.
|
||||||
//
|
//
|
||||||
// Invalidated function call node name -> Inlined side-effectful nodes
|
// Invalidated function call node name -> Inlined side-effectful nodes
|
||||||
gtl::FlatMap<string, std::vector<string>> control_overrides_;
|
absl::flat_hash_map<string, std::vector<string>> control_overrides_;
|
||||||
|
|
||||||
// Use graph view to find active outputs of the function caller nodes.
|
// Use graph view to find active outputs of the function caller nodes.
|
||||||
GraphView graph_view_;
|
GraphView graph_view_;
|
||||||
@ -472,10 +472,10 @@ const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
|
|||||||
return ctx.function_library().Find(node.op());
|
return ctx.function_library().Find(node.op());
|
||||||
}
|
}
|
||||||
|
|
||||||
gtl::FlatSet<int> GetActiveOutputs(const NodeDef& node,
|
absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
|
||||||
const FunctionOptimizerContext& ctx,
|
const FunctionOptimizerContext& ctx,
|
||||||
int size_hint = 0) {
|
int size_hint = 0) {
|
||||||
gtl::FlatSet<int> active_outputs;
|
absl::flat_hash_set<int> active_outputs;
|
||||||
active_outputs.reserve(static_cast<size_t>(size_hint));
|
active_outputs.reserve(static_cast<size_t>(size_hint));
|
||||||
|
|
||||||
// 1. Output can be consumed by the other graph node.
|
// 1. Output can be consumed by the other graph node.
|
||||||
@ -508,7 +508,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
|
|||||||
// number of output args is the same as number of possible function caller
|
// number of output args is the same as number of possible function caller
|
||||||
// node outputs.
|
// node outputs.
|
||||||
int num_outputs = func.signature().output_arg_size();
|
int num_outputs = func.signature().output_arg_size();
|
||||||
const gtl::FlatSet<int> active_outputs =
|
const absl::flat_hash_set<int> active_outputs =
|
||||||
GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
|
GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
|
||||||
|
|
||||||
return active_outputs.size() != num_outputs;
|
return active_outputs.size() != num_outputs;
|
||||||
@ -519,7 +519,7 @@ bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
|
|||||||
FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
|
FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
|
||||||
const GraphDef& optimized_graph) {
|
const GraphDef& optimized_graph) {
|
||||||
FunctionLibraryDefinition pruned_flib =
|
FunctionLibraryDefinition pruned_flib =
|
||||||
ReachableFunctionLibraryDefinition(flib, optimized_graph);
|
flib.ReachableDefinitions(optimized_graph);
|
||||||
|
|
||||||
int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
|
int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
|
||||||
static_cast<int>(flib.num_functions());
|
static_cast<int>(flib.num_functions());
|
||||||
@ -534,8 +534,8 @@ FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
|
|||||||
Status PushDownConstInputs(const NodeDef& func_node,
|
Status PushDownConstInputs(const NodeDef& func_node,
|
||||||
const FunctionOptimizerContext& ctx,
|
const FunctionOptimizerContext& ctx,
|
||||||
GrapplerFunctionItem* item,
|
GrapplerFunctionItem* item,
|
||||||
gtl::FlatSet<string>* const_inputs,
|
absl::flat_hash_set<string>* const_inputs,
|
||||||
gtl::FlatSet<string>* control_deps) {
|
absl::flat_hash_set<string>* control_deps) {
|
||||||
// Record node control dependencies in the control_deps set.
|
// Record node control dependencies in the control_deps set.
|
||||||
const auto record_control_deps = [&](const NodeDef* const_input) {
|
const auto record_control_deps = [&](const NodeDef* const_input) {
|
||||||
for (int i = const_input->input_size() - 1; i >= 0; --i) {
|
for (int i = const_input->input_size() - 1; i >= 0; --i) {
|
||||||
@ -585,7 +585,7 @@ void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
|
|||||||
|
|
||||||
// Attach control dependencies of pushed down const input to the caller node.
|
// Attach control dependencies of pushed down const input to the caller node.
|
||||||
if (!specialization.control_deps.empty()) {
|
if (!specialization.control_deps.empty()) {
|
||||||
gtl::FlatSet<string> existing_control_deps;
|
absl::flat_hash_set<string> existing_control_deps;
|
||||||
|
|
||||||
for (const string& input : keep_inputs) {
|
for (const string& input : keep_inputs) {
|
||||||
existing_control_deps.insert(AsControlDependency(NodeName(input)));
|
existing_control_deps.insert(AsControlDependency(NodeName(input)));
|
||||||
@ -797,8 +797,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
|
|||||||
|
|
||||||
// Push const inputs into the function body, and keep track of their control
|
// Push const inputs into the function body, and keep track of their control
|
||||||
// dependencies.
|
// dependencies.
|
||||||
gtl::FlatSet<string> const_inputs;
|
absl::flat_hash_set<string> const_inputs;
|
||||||
gtl::FlatSet<string> control_deps;
|
absl::flat_hash_set<string> control_deps;
|
||||||
TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
|
TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
|
||||||
&control_deps));
|
&control_deps));
|
||||||
|
|
||||||
@ -1005,7 +1005,7 @@ Status InlineDirectFunctionCall(const NodeDef& func_node,
|
|||||||
|
|
||||||
// Mapping from input placeholder name to function input position.
|
// Mapping from input placeholder name to function input position.
|
||||||
int idx = 0;
|
int idx = 0;
|
||||||
std::unordered_map<string, int> input_placeholders_idx;
|
absl::flat_hash_map<string, int> input_placeholders_idx;
|
||||||
for (const InputArgExpansion& input_arg : item.inputs()) {
|
for (const InputArgExpansion& input_arg : item.inputs()) {
|
||||||
for (const string& placeholder : input_arg.placeholders) {
|
for (const string& placeholder : input_arg.placeholders) {
|
||||||
input_placeholders_idx[placeholder] = idx++;
|
input_placeholders_idx[placeholder] = idx++;
|
||||||
@ -1699,7 +1699,7 @@ Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
|
|||||||
if (!ctx.control_overrides().empty()) {
|
if (!ctx.control_overrides().empty()) {
|
||||||
for (NodeDef& node : *optimized_graph->mutable_node()) {
|
for (NodeDef& node : *optimized_graph->mutable_node()) {
|
||||||
// Keep track of new control inputs to the node.
|
// Keep track of new control inputs to the node.
|
||||||
gtl::FlatSet<string> add_ctrl_inputs;
|
absl::flat_hash_set<string> add_ctrl_inputs;
|
||||||
|
|
||||||
// Remove all invalidated control inputs.
|
// Remove all invalidated control inputs.
|
||||||
for (int idx = 0; idx < node.input_size(); /* see below */) {
|
for (int idx = 0; idx < node.input_size(); /* see below */) {
|
||||||
|
@ -427,6 +427,14 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
VLOG(1) << "Starting optimization for grappler item: " << item.id;
|
VLOG(1) << "Starting optimization for grappler item: " << item.id;
|
||||||
optimization_results_.clear();
|
optimization_results_.clear();
|
||||||
|
|
||||||
|
// Constructs a FunctionLibraryDefinition with functions that are reachable
|
||||||
|
// from the nodes of the graph.
|
||||||
|
const auto minimized_flib =
|
||||||
|
[](const GraphDef& graph) -> FunctionLibraryDefinition {
|
||||||
|
return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
|
||||||
|
.ReachableDefinitions(graph);
|
||||||
|
};
|
||||||
|
|
||||||
// 0. Original graph might contain a huge function library, that is mostly
|
// 0. Original graph might contain a huge function library, that is mostly
|
||||||
// unused. This library copied over by each individual Grappler optimizer,
|
// unused. This library copied over by each individual Grappler optimizer,
|
||||||
// which adds a huge overhead. Before starting optimization passes we just
|
// which adds a huge overhead. Before starting optimization passes we just
|
||||||
@ -436,11 +444,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
GraphDef trimmed_graph; // do not copy graph with a potentially huge library
|
GraphDef trimmed_graph; // do not copy graph with a potentially huge library
|
||||||
*trimmed_graph.mutable_node() = item.graph.node();
|
*trimmed_graph.mutable_node() = item.graph.node();
|
||||||
*trimmed_graph.mutable_versions() = item.graph.versions();
|
*trimmed_graph.mutable_versions() = item.graph.versions();
|
||||||
*trimmed_graph.mutable_library() =
|
*trimmed_graph.mutable_library() = minimized_flib(item.graph).ToProto();
|
||||||
grappler::ReachableFunctionLibraryDefinition(
|
|
||||||
FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()),
|
|
||||||
item.graph)
|
|
||||||
.ToProto();
|
|
||||||
|
|
||||||
GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
|
GrapplerItem trimmed_item = item.WithGraph(std::move(trimmed_graph));
|
||||||
|
|
||||||
@ -472,10 +476,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 2. Optimize functions reachable from the optimized graph.
|
// 2. Optimize functions reachable from the optimized graph.
|
||||||
FunctionLibraryDefinition flib = ReachableFunctionLibraryDefinition(
|
FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
|
||||||
FunctionLibraryDefinition(OpRegistry::Global(),
|
|
||||||
optimized_graph->library()),
|
|
||||||
*optimized_graph);
|
|
||||||
|
|
||||||
// Find functions for which we might need to compute a gradient at runtime.
|
// Find functions for which we might need to compute a gradient at runtime.
|
||||||
absl::flat_hash_set<string> differentiable_functions;
|
absl::flat_hash_set<string> differentiable_functions;
|
||||||
|
@ -178,6 +178,9 @@ cc_library(
|
|||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -196,6 +199,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,8 +14,9 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
|
|
||||||
#include <unordered_map>
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/substitute.h"
|
#include "absl/strings/substitute.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
@ -28,7 +29,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/versions.pb.h"
|
#include "tensorflow/core/framework/versions.pb.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
|
||||||
#include "tensorflow/core/lib/strings/scanner.h"
|
#include "tensorflow/core/lib/strings/scanner.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -76,16 +76,6 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|
||||||
const FunctionLibraryDefinition& flib, const GraphDef& graph) {
|
|
||||||
return flib.ReachableDefinitions(graph);
|
|
||||||
}
|
|
||||||
|
|
||||||
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|
||||||
const FunctionLibraryDefinition& flib, const FunctionDef& func) {
|
|
||||||
return flib.ReachableDefinitions(func);
|
|
||||||
}
|
|
||||||
|
|
||||||
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
|
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
|
||||||
InputArgExpansion input_arg_expansion) {
|
InputArgExpansion input_arg_expansion) {
|
||||||
string input_name = input_arg_expansion.input_name;
|
string input_name = input_arg_expansion.input_name;
|
||||||
@ -94,7 +84,7 @@ void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
|
|||||||
for (int i = 0; i < placeholders.size(); ++i) {
|
for (int i = 0; i < placeholders.size(); ++i) {
|
||||||
const string& placeholder = input_arg_expansion.placeholders[i];
|
const string& placeholder = input_arg_expansion.placeholders[i];
|
||||||
input_arg_placeholders_.insert(
|
input_arg_placeholders_.insert(
|
||||||
{placeholder, InputArgPlaceholder{input_name, /*input_position=*/i}});
|
{placeholder, InputArgPlaceholder{input_name, /*input_index=*/i}});
|
||||||
}
|
}
|
||||||
input_arg_expansions_.insert(
|
input_arg_expansions_.insert(
|
||||||
{std::move(input_name), std::move(input_arg_expansion)});
|
{std::move(input_name), std::move(input_arg_expansion)});
|
||||||
@ -193,7 +183,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
|
|||||||
// If position is not defined expand node output range
|
// If position is not defined expand node output range
|
||||||
for (int i = output_range.first; i < output_range.second; ++i) {
|
for (int i = output_range.first; i < output_range.second; ++i) {
|
||||||
graph_def_inputs->push_back(
|
graph_def_inputs->push_back(
|
||||||
i == 0 ? node_name : strings::StrCat(node_name, ":", i));
|
i == 0 ? node_name : absl::StrCat(node_name, ":", i));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (position > (output_range.second - output_range.first)) {
|
if (position > (output_range.second - output_range.first)) {
|
||||||
@ -203,7 +193,7 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
|
|||||||
}
|
}
|
||||||
int pos = output_range.first + position;
|
int pos = output_range.first + position;
|
||||||
graph_def_inputs->push_back(
|
graph_def_inputs->push_back(
|
||||||
pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
|
pos == 0 ? node_name : absl::StrCat(node_name, ":", pos));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -232,39 +222,39 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
|
|||||||
|
|
||||||
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
|
Status GrapplerFunctionConnectivity::AsFunctionDefInput(
|
||||||
const string& graph_def_input, string* func_def_input) const {
|
const string& graph_def_input, string* func_def_input) const {
|
||||||
using gtl::FindOrNull;
|
|
||||||
|
|
||||||
if (IsControlInput(graph_def_input)) {
|
if (IsControlInput(graph_def_input)) {
|
||||||
*func_def_input = graph_def_input;
|
*func_def_input = graph_def_input;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
int position;
|
const TensorId tensor = ParseTensorName(graph_def_input);
|
||||||
string node_name = ParseNodeName(graph_def_input, &position);
|
DCHECK_GE(tensor.index(), 0);
|
||||||
CHECK_GE(position, 0);
|
|
||||||
|
const absl::string_view node_name = tensor.node();
|
||||||
|
const int index = tensor.index();
|
||||||
|
|
||||||
// Check if it's an input arg placeholder
|
// Check if it's an input arg placeholder
|
||||||
if (position == 0) {
|
if (tensor.index() == 0) {
|
||||||
const InputArgPlaceholder* placeholder =
|
const auto is_input_placeholder = input_arg_placeholders_.find(node_name);
|
||||||
FindOrNull(input_arg_placeholders_, node_name);
|
if (is_input_placeholder != input_arg_placeholders_.end()) {
|
||||||
if (placeholder != nullptr) {
|
const InputArgPlaceholder& placeholder = is_input_placeholder->second;
|
||||||
*func_def_input = strings::StrCat(placeholder->input_name, ":",
|
*func_def_input =
|
||||||
placeholder->input_position);
|
absl::StrCat(placeholder.input_name, ":", placeholder.input_index);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// It must be output from one of the function body nodes
|
// It must be output from one of the function body nodes
|
||||||
const tensorflow::NameRangeMap* outputs_range_map =
|
const auto is_body_output = function_body_outputs_.find(tensor.node());
|
||||||
FindOrNull(function_body_outputs_, node_name);
|
if (is_body_output != function_body_outputs_.end()) {
|
||||||
if (outputs_range_map != nullptr) {
|
const tensorflow::NameRangeMap& outputs_range_map = is_body_output->second;
|
||||||
for (const auto& el : *outputs_range_map) {
|
|
||||||
|
for (const auto& el : outputs_range_map) {
|
||||||
const auto& output_name = el.first;
|
const auto& output_name = el.first;
|
||||||
const auto& output_range = el.second;
|
const auto& output_range = el.second;
|
||||||
if (position >= output_range.first && position < output_range.second) {
|
if (index >= output_range.first && index < output_range.second) {
|
||||||
int pos = position - output_range.first;
|
int pos = index - output_range.first;
|
||||||
*func_def_input =
|
*func_def_input = absl::StrCat(node_name, ":", output_name, ":", pos);
|
||||||
strings::StrCat(node_name, ":", output_name, ":", pos);
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -426,7 +416,7 @@ bool IsParametrized(const FunctionDef& func) {
|
|||||||
|
|
||||||
Status InstantiationTypeParameters(
|
Status InstantiationTypeParameters(
|
||||||
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
||||||
std::unordered_map<string, DataType>* type_parameters) {
|
absl::flat_hash_map<string, DataType>* type_parameters) {
|
||||||
if (!type_parameters->empty()) {
|
if (!type_parameters->empty()) {
|
||||||
return errors::InvalidArgument("Type parameters output map must be empty");
|
return errors::InvalidArgument("Type parameters output map must be empty");
|
||||||
}
|
}
|
||||||
@ -454,7 +444,7 @@ Status InstantiationTypeParameters(
|
|||||||
|
|
||||||
Status InstantiationBodyParameters(
|
Status InstantiationBodyParameters(
|
||||||
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
||||||
std::unordered_map<string, AttrValue>* body_parameters) {
|
absl::flat_hash_map<string, AttrValue>* body_parameters) {
|
||||||
if (!body_parameters->empty()) {
|
if (!body_parameters->empty()) {
|
||||||
return errors::InvalidArgument("Body parameters output map must be empty");
|
return errors::InvalidArgument("Body parameters output map must be empty");
|
||||||
}
|
}
|
||||||
@ -514,8 +504,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
|
|||||||
|
|
||||||
// Function body shares the library with the graph that instantiated it. We do
|
// Function body shares the library with the graph that instantiated it. We do
|
||||||
// not need a full copy of the function library, just the reachable subset.
|
// not need a full copy of the function library, just the reachable subset.
|
||||||
*function_body.mutable_library() =
|
*function_body.mutable_library() = flib.ReachableDefinitions(func).ToProto();
|
||||||
ReachableFunctionLibraryDefinition(flib, func).ToProto();
|
|
||||||
|
|
||||||
VLOG(3) << absl::Substitute(
|
VLOG(3) << absl::Substitute(
|
||||||
"Deleted $0 unreachable functions from the Grappler function item "
|
"Deleted $0 unreachable functions from the Grappler function item "
|
||||||
@ -645,7 +634,7 @@ Status RegisterGrapplerFunctionConnectivity(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
|
||||||
GrapplerFunctionItem* item) {
|
GrapplerFunctionItem* item) {
|
||||||
if (!IsConstant(input_const)) {
|
if (!IsConstant(input_const)) {
|
||||||
return errors::InvalidArgument("Input node ", input_const.name(),
|
return errors::InvalidArgument("Input node ", input_const.name(),
|
||||||
@ -657,7 +646,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
|||||||
// Find input arg expansion and input placeholder position in it for the
|
// Find input arg expansion and input placeholder position in it for the
|
||||||
// given function input position.
|
// given function input position.
|
||||||
InputArgExpansion* input_arg_expansion = nullptr;
|
InputArgExpansion* input_arg_expansion = nullptr;
|
||||||
int placeholder_idx = input_position;
|
int placeholder_idx = input_index;
|
||||||
|
|
||||||
for (InputArgExpansion& input : inputs) {
|
for (InputArgExpansion& input : inputs) {
|
||||||
if (placeholder_idx < input.placeholders.size()) {
|
if (placeholder_idx < input.placeholders.size()) {
|
||||||
@ -668,9 +657,8 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (input_arg_expansion == nullptr) {
|
if (input_arg_expansion == nullptr) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument("Input placeholder not found: input_index=",
|
||||||
"Input placeholder not found: input_position=", input_position,
|
input_index, " function=", item->id);
|
||||||
" function=", item->id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete placeholder from input expansion.
|
// Delete placeholder from input expansion.
|
||||||
@ -699,7 +687,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
|
Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
|
||||||
GrapplerFunctionItem* item,
|
GrapplerFunctionItem* item,
|
||||||
std::vector<std::pair<int, int>>* output_mapping) {
|
std::vector<std::pair<int, int>>* output_mapping) {
|
||||||
DCHECK(output_mapping->empty());
|
DCHECK(output_mapping->empty());
|
||||||
@ -713,7 +701,7 @@ Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
gtl::FlatSet<const OutputArgExpansion*> unused_output_args;
|
absl::flat_hash_set<const OutputArgExpansion*> unused_output_args;
|
||||||
|
|
||||||
const auto is_unused_output_arg = [&](const OutputArgExpansion& output) {
|
const auto is_unused_output_arg = [&](const OutputArgExpansion& output) {
|
||||||
return unused_output_args.find(&output) != unused_output_args.end();
|
return unused_output_args.find(&output) != unused_output_args.end();
|
||||||
|
@ -18,7 +18,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/function.pb.h"
|
#include "tensorflow/core/framework/function.pb.h"
|
||||||
@ -30,13 +32,6 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace grappler {
|
namespace grappler {
|
||||||
|
|
||||||
// Returns a copy of FunctionLibraryDefinition with subset of functions that are
|
|
||||||
// reachable from the nodes of the graph.
|
|
||||||
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|
||||||
const FunctionLibraryDefinition& flib, const GraphDef& graph);
|
|
||||||
FunctionLibraryDefinition ReachableFunctionLibraryDefinition(
|
|
||||||
const FunctionLibraryDefinition& flib, const FunctionDef& func);
|
|
||||||
|
|
||||||
// Depending on the function instantiation attributes, input argument to the
|
// Depending on the function instantiation attributes, input argument to the
|
||||||
// function might be a single tensor, list of tensors of the same type, or a
|
// function might be a single tensor, list of tensors of the same type, or a
|
||||||
// list of tensors of different types.
|
// list of tensors of different types.
|
||||||
@ -81,12 +76,12 @@ class GrapplerFunctionConnectivity {
|
|||||||
void RegisterFunctionBodyOutputs(const string& node_name,
|
void RegisterFunctionBodyOutputs(const string& node_name,
|
||||||
tensorflow::NameRangeMap&& outputs);
|
tensorflow::NameRangeMap&& outputs);
|
||||||
|
|
||||||
// Expand input encoded in FunctionDef format (name[:output][:position]) into
|
// Expands input encoded in FunctionDef format (name[:output][:position]) into
|
||||||
// multiple inputs in GraphDef format (name[:position]).
|
// multiple inputs in GraphDef format (name[:position]).
|
||||||
Status ExpandFunctionDefInput(const string& func_def_input,
|
Status ExpandFunctionDefInput(const string& func_def_input,
|
||||||
std::vector<string>* graph_def_inputs) const;
|
std::vector<string>* graph_def_inputs) const;
|
||||||
|
|
||||||
// Update Node inputs from FunctionDef to GraphDef format.
|
// Updates Node inputs from FunctionDef to GraphDef format.
|
||||||
Status ExpandNodeInputs(NodeDef* function_body_node) const;
|
Status ExpandNodeInputs(NodeDef* function_body_node) const;
|
||||||
|
|
||||||
// When expanding inputs in function def format, single input might be
|
// When expanding inputs in function def format, single input might be
|
||||||
@ -96,29 +91,31 @@ class GrapplerFunctionConnectivity {
|
|||||||
// instantiation attributes and length of input args (and node def outputs) is
|
// instantiation attributes and length of input args (and node def outputs) is
|
||||||
// known.
|
// known.
|
||||||
|
|
||||||
// Map from GraphDef input format to FunctionDef input format using registered
|
// Converts input name from GraphDef format (name[:position]) to the
|
||||||
// input arg expansion and function body outputs.
|
// FunctionDef input format (name[:output][:position]) using registered input
|
||||||
|
// arg expansion and function body outputs.
|
||||||
Status AsFunctionDefInput(const string& graph_def_input,
|
Status AsFunctionDefInput(const string& graph_def_input,
|
||||||
string* func_def_input) const;
|
string* func_def_input) const;
|
||||||
|
|
||||||
// Update Node inputs from GraphDef to FunctionDef format.
|
// Updates Node inputs from GraphDef to FunctionDef format.
|
||||||
Status AsFunctionDefNode(NodeDef* function_body_node) const;
|
Status AsFunctionDefNode(NodeDef* function_body_node) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Mapping from input name to input arg expansion.
|
// Mapping from input name to input arg expansion.
|
||||||
std::unordered_map<string, InputArgExpansion> input_arg_expansions_;
|
absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
|
||||||
// Mapping from function body node name to output names range map.
|
// Mapping from function body node name to output names range map.
|
||||||
std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_;
|
absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
|
||||||
|
|
||||||
|
// For each placeholder added to the function instantiation graph, we keep a
|
||||||
|
// mapping back to the function input argument name and index.
|
||||||
struct InputArgPlaceholder {
|
struct InputArgPlaceholder {
|
||||||
string input_name; // Name of the function input argument.
|
string input_name; // Name of the function input argument.
|
||||||
int input_position; // Index of a tensor in the function input argument
|
int input_index; // Index of a tensor in the function input argument
|
||||||
// expansion, it can be greater than `0` if input
|
// expansion, it can be greater than `0` if input
|
||||||
// argument is a list of tensors (aka list(type)).
|
// argument is a list of tensors (aka list(type)).
|
||||||
};
|
};
|
||||||
|
|
||||||
// Mapping from input arg placeholder to the function input tensor.
|
// Mapping from input arg placeholder to the function input tensor.
|
||||||
std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_;
|
absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Get Function type attributes using attributes of a node that instantiated
|
// Get Function type attributes using attributes of a node that instantiated
|
||||||
@ -172,7 +169,8 @@ class GrapplerFunctionItem : public GrapplerItem {
|
|||||||
friend Status ReplaceInputWithConst(const NodeDef&, int,
|
friend Status ReplaceInputWithConst(const NodeDef&, int,
|
||||||
GrapplerFunctionItem*);
|
GrapplerFunctionItem*);
|
||||||
friend Status RemoveUnusedOutputs(
|
friend Status RemoveUnusedOutputs(
|
||||||
const gtl::FlatSet<int>& active_outputs, GrapplerFunctionItem* item,
|
const absl::flat_hash_set<int>& active_outputs,
|
||||||
|
GrapplerFunctionItem* item,
|
||||||
std::vector<std::pair<int, int>>* output_mapping);
|
std::vector<std::pair<int, int>>* output_mapping);
|
||||||
|
|
||||||
GrapplerFunctionItem(string func_name, string description,
|
GrapplerFunctionItem(string func_name, string description,
|
||||||
@ -191,7 +189,7 @@ class GrapplerFunctionItem : public GrapplerItem {
|
|||||||
|
|
||||||
std::set<string> input_arg_placeholders_;
|
std::set<string> input_arg_placeholders_;
|
||||||
|
|
||||||
bool is_stateful_;
|
bool is_stateful_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if function input/output types are fully defined only at instantiation
|
// Check if function input/output types are fully defined only at instantiation
|
||||||
@ -210,14 +208,14 @@ bool IsParametrized(const FunctionDef& func);
|
|||||||
// caller node. Return error if type can't be resolved.
|
// caller node. Return error if type can't be resolved.
|
||||||
Status InstantiationTypeParameters(
|
Status InstantiationTypeParameters(
|
||||||
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
||||||
std::unordered_map<string, DataType>* type_parameters);
|
absl::flat_hash_map<string, DataType>* type_parameters);
|
||||||
|
|
||||||
// Resolve function instantiation body parameters (values for the function body
|
// Resolve function instantiation body parameters (values for the function body
|
||||||
// attr placeholders) from the attributes of the caller node. Return error if
|
// attr placeholders) from the attributes of the caller node. Return error if
|
||||||
// type can't be resolved.
|
// type can't be resolved.
|
||||||
Status InstantiationBodyParameters(
|
Status InstantiationBodyParameters(
|
||||||
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
const FunctionDef& func, const AttrSlice& func_instantiation_attr,
|
||||||
std::unordered_map<string, AttrValue>* body_parameters);
|
absl::flat_hash_map<string, AttrValue>* body_parameters);
|
||||||
|
|
||||||
// Register GrapplerFunctionItem input arg expansion and function body outputs
|
// Register GrapplerFunctionItem input arg expansion and function body outputs
|
||||||
// in the GrapplerFunctionConnectivity. Use function library definition to
|
// in the GrapplerFunctionConnectivity. Use function library definition to
|
||||||
@ -227,7 +225,7 @@ Status RegisterGrapplerFunctionConnectivity(
|
|||||||
GrapplerFunctionConnectivity* connectivity);
|
GrapplerFunctionConnectivity* connectivity);
|
||||||
|
|
||||||
// Replace one of the function inputs with a constant.
|
// Replace one of the function inputs with a constant.
|
||||||
Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
|
||||||
GrapplerFunctionItem* item);
|
GrapplerFunctionItem* item);
|
||||||
|
|
||||||
// Remove function output arguments that do not have any active outputs (output
|
// Remove function output arguments that do not have any active outputs (output
|
||||||
@ -236,7 +234,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
|
|||||||
// potentially be connected to the same output argument (in case of tensor list
|
// potentially be connected to the same output argument (in case of tensor list
|
||||||
// outputs). Add output mapping for all active outputs that changed it's output
|
// outputs). Add output mapping for all active outputs that changed it's output
|
||||||
// position (std::pair<old position, new position>).
|
// position (std::pair<old position, new position>).
|
||||||
Status RemoveUnusedOutputs(const gtl::FlatSet<int>& active_outputs,
|
Status RemoveUnusedOutputs(const absl::flat_hash_set<int>& active_outputs,
|
||||||
GrapplerFunctionItem* item,
|
GrapplerFunctionItem* item,
|
||||||
std::vector<std::pair<int, int>>* output_mapping);
|
std::vector<std::pair<int, int>>* output_mapping);
|
||||||
|
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/core/grappler/utils/functions.h"
|
#include "tensorflow/core/grappler/utils/functions.h"
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/node_def.pb.h"
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
@ -77,7 +79,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
|
|||||||
func_instantiation_attr["B"].set_type(DT_INT32);
|
func_instantiation_attr["B"].set_type(DT_INT32);
|
||||||
func_instantiation_attr["C"].set_type(DT_DOUBLE);
|
func_instantiation_attr["C"].set_type(DT_DOUBLE);
|
||||||
|
|
||||||
std::unordered_map<string, DataType> type_parameters;
|
absl::flat_hash_map<string, DataType> type_parameters;
|
||||||
TF_EXPECT_OK(InstantiationTypeParameters(
|
TF_EXPECT_OK(InstantiationTypeParameters(
|
||||||
func, AttrSlice(&func_instantiation_attr), &type_parameters));
|
func, AttrSlice(&func_instantiation_attr), &type_parameters));
|
||||||
|
|
||||||
@ -86,7 +88,7 @@ TEST_F(FunctionsTest, InstantiationParameters) {
|
|||||||
EXPECT_EQ(DT_INT32, type_parameters["B"]);
|
EXPECT_EQ(DT_INT32, type_parameters["B"]);
|
||||||
EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
|
EXPECT_EQ(DT_DOUBLE, type_parameters["C"]);
|
||||||
|
|
||||||
std::unordered_map<string, AttrValue> body_parameters;
|
absl::flat_hash_map<string, AttrValue> body_parameters;
|
||||||
TF_EXPECT_OK(InstantiationBodyParameters(
|
TF_EXPECT_OK(InstantiationBodyParameters(
|
||||||
func, AttrSlice(&func_instantiation_attr), &body_parameters));
|
func, AttrSlice(&func_instantiation_attr), &body_parameters));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user