/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
|
|
|
|
#include <functional>
|
|
#include <queue>
|
|
#include <random>
|
|
#include <set>
|
|
#include <unordered_map>
|
|
|
|
#include "absl/strings/str_cat.h"
|
|
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
|
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/common_runtime/function.h"
|
|
#include "tensorflow/core/framework/graph.pb.h"
|
|
#include "tensorflow/core/framework/graph_def_util.h"
|
|
#include "tensorflow/core/framework/graph_to_functiondef.h"
|
|
#include "tensorflow/core/framework/node_def.pb.h"
|
|
#include "tensorflow/core/framework/node_def_builder.h"
|
|
#include "tensorflow/core/framework/node_def_util.h"
|
|
#include "tensorflow/core/framework/tensor_shape.h"
|
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
|
#include "tensorflow/core/framework/versions.pb.h"
|
|
#include "tensorflow/core/graph/tensor_id.h"
|
|
#include "tensorflow/core/lib/core/errors.h"
|
|
#include "tensorflow/core/lib/core/status.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
namespace {
|
|
|
|

Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  if (id.output_index() < 0) {
    return errors::InvalidArgument(
        "TensorId output_index must be non-negative");
  }
  return Status::OK();
}

Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
  }
  return Status::OK();
}

Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
  // since that will cause a collision in codegen symbols.
  for (const string& name : names) {
    const string name_data(name + "_data");
    if (names.find(name_data) != names.end()) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                     " and ", name_data);
    }
  }
  return Status::OK();
}
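
// For example, the check above rejects a feed set containing both "foo" and
// "foo_data": the generated code derives a "<name>_data" symbol for each
// feed, so the symbol derived from "foo" would collide with the feed named
// "foo_data".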

// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld`
// to `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph
// `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          break;
      }
    }
  }
  return Status::OK();
}

// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const std::unordered_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes.
  std::unordered_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* arg_node = arg_nodes[arg_index];

    // Collect all usages of the _Arg node.
    struct OutEdgeInfo {
      int dst_node_id, dst_input;
    };
    std::vector<OutEdgeInfo> usages;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      usages.push_back({e->dst()->id(), e->dst_input()});
    }

    for (int i = 0; i < usages.size(); i++) {
      // Make a copy of `usage_node`, and change its input to const node.
      Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
      NodeDef replace_def = usage_node->def();
      *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
      TF_ASSIGN_OR_RETURN(Node * replace_node,
                          ReplaceNode(g, usage_node, replace_def));
      const Edge* usage_edge;
      TF_RETURN_IF_ERROR(
          replace_node->input_edge(usages[i].dst_input, &usage_edge));
      g->RemoveEdge(usage_edge);
      g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);

      // Later entries in `usages` might have `usage_node` as dst node, but
      // `usage_node` is removed. Replace such entries with `replace_node`.
      for (int j = i + 1; j < usages.size(); j++) {
        if (usages[j].dst_node_id == usages[i].dst_node_id) {
          usages[j].dst_node_id = replace_node->id();
        }
      }
    }
  }
  return Status::OK();
}
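
// Illustrative sketch of the rewrite above (names are hypothetical): with
// const_input_index_to_node = {0: %c}, a function body
//
//   %arg0 = _Arg(index=0)          %c' = Const(copy of %c)
//   %add  = Add(%arg0, %x)   ==>   %add = Add(%c', %x)
//
// has each non-control use of %arg0 re-pointed at the copied Const; the _Arg
// node itself remains in the graph, now unused.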

// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const std::unordered_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));

  // Save rewritten function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));

  // Change the node to use rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);

  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return Status::OK();
}

// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Notice that the first input of an "If" node is the predicate; the
  // remaining inputs are function inputs.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  for (int i = 1; i < if_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
    if (input_node->type_string() == "Const") {
      const_input_index_to_node[i - 1] = input_node;
    }
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "then_branch" and "else_branch" functions, replacing usages
  // of those _Arg nodes with the corresponding Const node.
  for (const auto& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }

  return Status::OK();
}
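
// Illustrative example for the rewrite above (names are hypothetical): given
//
//   %c = Const(...)
//   %r = If(%pred, %c, %x, then_branch=f, else_branch=g)
//
// node input 1 (%c) feeds _Arg index 0 inside f and g, so both branches are
// rewritten (as "f_const_...", "g_const_...") with that _Arg replaced by a
// copy of %c.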

// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  // For a "While" node, we should only replace _Arg nodes which are loop
  // invariant. For such an _Arg node, the body function's return value at
  // the same index comes directly from the corresponding arg.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (!body_func) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  for (int i = 0; i < while_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
    if (input_node->type_string() != "Const") {
      continue;
    }

    // Check if the i-th retval's input comes from the i-th arg directly.
    // For resource variable inputs of While nodes, the TF2XLA convention is
    // to place them at the end of all inputs (after all data inputs), and
    // *not* return them. So the number of While node inputs might be larger
    // than the number of its outputs.
    if (i >= body_func->signature().output_arg_size()) {
      continue;
    }
    const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
    auto output_arg_input = body_func->ret().find(output_arg.name());
    if (output_arg_input == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ",
                              output_arg.name(), " in function ",
                              body_attr.name());
    }
    const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
    if (output_arg_input->second != input_arg.name()) {
      continue;
    }

    const_input_index_to_node[i] = input_node;
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
  // nodes with the corresponding Const node.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
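
// Illustrative body function for the check above (hypothetical): an input is
// loop invariant only when the body returns it unchanged, i.e.
//
//   body(x: T, c: T) -> (f(x, c), c)
//
// Retval 1 maps directly to arg 1, so a Const fed into While input 1 can be
// propagated into cond/body; arg 0 is recomputed every iteration and is left
// alone.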

}  // namespace

const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";

Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return Status::OK();
}
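
// A minimal config accepted by ValidateConfig (values are illustrative):
//
//   feed { id { node_name: "x_hold" } shape { dim { size: 2 } } name: "x" }
//   fetch { id { node_name: "x_y_sum" } name: "sum" }
//
// An empty node_name, a negative output_index, duplicate feed/fetch names,
// or an empty fetch list would each be rejected.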

Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for
  // determinism when creating placeholders (genrules want deterministic
  // output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat(
        "aot_feed_", feed->id().output_index(), "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end();
       ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more
  // cases in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end();
       ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("Placeholder");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}
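
// For a feed on tensor "input:0" (an illustrative name), the function above
// adds a node
//
//   node { name: "aot_feed_0/input" op: "Placeholder" ... }
//
// records the mapping "input:0" -> "aot_feed_0/input" in `feed_remapping`,
// and rewrites every consumer of "input:0" to read the placeholder instead.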

Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over, preserving order of original and only nodes that are reachable
  // from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}

string TensorIdToString(const tf2xla::TensorId& id) {
  return absl::StrCat(id.node_name(), ":", id.output_index());
}
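
// For example, a TensorId with node_name "foo" and output_index 2
// stringifies as "foo:2".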

Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}

void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
    if (constraint.name() == name) {
      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
    }
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
  std::srand(seed);
  return std::rand() | 1;
}
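
// The odd-seed invariant survives wraparound: starting from an odd value,
// repeatedly adding 2 modulo 2^32 only ever visits odd values, so neither
// the counter nor the returned value (which is also OR'd with 1) can be zero.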

// TODO(b/77601805): add tests for associated function related stuff.

bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld) {
  if (fld->Contains(node_def.op())) {
    return true;
  }

  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
    // Gradient op has "f" attr, which is set to the function we are getting
    // gradient for. We need to functionalize the gradient function.
    return true;
  }

  if (node_def.op() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
    return false;
  }

  for (const auto& iter : node_def.attr()) {
    if (iter.second.has_func()) {
      return true;
    }
  }

  return false;
}

std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld) {
  std::vector<AssociatedFunctionInfo> results;
  const string& op = node.type_string();
  if (fld->Contains(op)) {
    // This is a function call node.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
    // This is a SymbolicGradient op.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
  } else if (node.type_string() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
  } else {
    // Collect all function attrs for the node.
    for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
  }
  return results;
}
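
// Illustrative outcomes (op/function names are hypothetical):
//  - A node whose op calls function "my_fn" yields one kFunctionCallNode
//    entry.
//  - A SymbolicGradient node yields one kSymbolicGradient entry.
//  - A node with a function-valued attr, e.g. body=my_fn on a "While" node,
//    yields one kFunctionAttr entry per such attr.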

Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      for (const auto& attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change function attr to rewritten functions.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}

Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
  string canonicalized_name = Canonicalize(func_name, attrs);
  auto iter = handles_.find(canonicalized_name);
  if (iter != handles_.end()) {
    *handle = iter->second;
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[canonicalized_name] = *handle;
  return Status::OK();
}
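
// Typical usage, as a sketch (`flr` is an assumed FunctionLibraryRuntime*):
//
//   CachedFunctionHandles cache(flr);
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(cache.GetOrInstantiate("my_fn", AttrSlice(), &handle));
//
// Repeated calls with the same name/attrs return the cached handle instead
// of instantiating again.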

Status CachedFunctionHandles::ReleaseAllHandles() {
  Status result;
  for (const auto& iter : handles_) {
    result.Update(flr_->ReleaseHandle(iter.second));
  }
  handles_.clear();
  return result;
}

xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record original node's output edges and remove them first. This is to
  // avoid multiple producers for dst nodes' input.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}

xla::StatusOr<Node*> BuildIdentityNode(
    Graph* graph, const string& node_name, DataType dtype, const Node* input,
    absl::optional<string> requested_device) {
  // Create identity node.
  NodeDef ndef;
  ndef.set_name(node_name);
  ndef.set_op("Identity");
  if (input) {
    ndef.add_input(input->name());
  }
  if (requested_device) {
    ndef.set_device(*requested_device);
  }
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* id_node = graph->AddNode(ndef, &s);
  TF_RETURN_IF_ERROR(s);
  return id_node;
}
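
// Sketch of a call site (names are hypothetical):
//
//   TF_ASSIGN_OR_RETURN(Node* id,
//                       BuildIdentityNode(graph, "my_identity", DT_FLOAT,
//                                         /*input=*/nullptr, absl::nullopt));
//
// Passing nullptr/nullopt yields an Identity node with no input and no
// requested device; callers wire up the data edge afterwards.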

Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    if (n->IsIfNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
    } else if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
    }
  }
  return Status::OK();
}

Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef graph_def;
  g.ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_functions =
      fld->ReachableDefinitions(graph_def);
  for (const string& func_name : fld->ListFunctionNames()) {
    if (!reachable_functions.Find(func_name)) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
    }
  }
  return Status::OK();
}
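
// Rewrites the TensorList-based forward/backward While pattern produced for
// loop gradients: when the forward While body pushes a Const element onto an
// EmptyTensorList, the backward While body's matching TensorListPopBack
// element output is replaced with a copy of that Const.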
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }

    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into the forward While body function and check if the
    // TensorListPushBack op has a Const input.
    NameAttrList fwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack.
      continue;
    }

    // Get the input for the TensorListPushBack node.
    Node* input_node;
    TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // Input for the TensorList is not a Const node.
      continue;
    }

    NodeDef const_input_nodedef = input_node->def();

    // Rewrite the backward While body function, replacing usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack.
      continue;
    }

    // Replace TensorListPopBack usages with the Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    Status s;
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
    TF_RETURN_IF_ERROR(s);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add the rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change the backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return Status::OK();
}

}  // namespace tensorflow