// Source: STT-tensorflow/tensorflow/compiler/tf2xla/resource_util.cc
// (333 lines, 12 KiB, C++)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/resource_util.h"
#include <string>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
namespace {

using stream_executor::port::StatusOr;

// Op type strings that the analysis matches by name.
const char kIdentityNOp[] = "IdentityN";
const char kIfOp[] = "If";
const char kWhileOp[] = "While";
const char kArgOp[] = "_Arg";
const char kRetvalOp[] = "_Retval";

// Maximum function-call nesting depth the analysis will follow before
// reporting an error (guards against runaway/cyclic function recursion).
const int kMaxCallDepth = 100;
// Analyzes `graph` (the body of function `function_name`, or the top-level
// graph when `function_name` is empty) and fills `source_to_path` with a map
// from each Stack/TensorArray resource source node to the set of nodes that
// use the resource it produces. `resource_arg_indices` lists the _Arg indices
// of `graph` that receive such resources from the caller. Forward-declared
// here because PropagateThroughCallOp recurses into it.
Status AnalyzeResourceUsage(
    const Graph* graph, const absl::optional<std::string>& function_name,
    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
    FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path);
// Returns true iff `n` is one of the deprecated control-flow-v1 primitives
// (Enter/Exit/Switch/Merge/NextIteration), which this analysis rejects.
bool IsControlFlowV1Node(const Node* n) {
  if (n->IsEnter()) return true;
  if (n->IsExit()) return true;
  if (n->IsSwitch()) return true;
  if (n->IsMerge()) return true;
  return n->IsNextIteration();
}
// TODO(ycao): Add this as Tensorflow Node method.
// Returns all output edges of `n` that originate from output slot `idx`.
// A node output may fan out to several consumers, hence the vector.
// TODO(ycao): Add this as Tensorflow Node method.
StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node& n,
                                                                 int idx) {
  absl::InlinedVector<const Edge*, 1> res;
  // Reject out-of-range slots. Negative indices must be excluded explicitly:
  // control edges carry src_output() == Graph::kControlSlot (-1), so a
  // negative `idx` would otherwise silently match every control edge instead
  // of producing an error.
  if (idx < 0 || idx >= n.num_outputs()) {
    return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
                                   n.name(), " only has ", n.num_outputs(),
                                   " outputs.");
  }
  for (const Edge* o : n.out_edges()) {
    if (o->src_output() == idx) res.emplace_back(o);
  }
  return res;
}
// A node is a Stack/TensorArray resource "source" when it is a registered
// resource op of kind kStack or kTensorArray and its first output yields a
// DT_RESOURCE handle.
bool IsStackOrTensorArraySource(const Node& n) {
  const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
  if (op_info == nullptr) return false;
  const bool is_stack_or_tensor_array =
      op_info->resource_kind() == XlaResourceKind::kStack ||
      op_info->resource_kind() == XlaResourceKind::kTensorArray;
  if (!is_stack_or_tensor_array) return false;
  if (n.num_outputs() == 0) return false;
  return n.output_type(0) == DataType::DT_RESOURCE;
}
// Records `n` as the resource source for every resource-typed data edge
// leaving it, by inserting those edges into `user_to_source`.
void PropagateFromStackOrTensorArraySourceOp(
    const Node& n, const absl::optional<std::string>& function_name,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  const ResourceUsageAnalysis::NodeInfo source(function_name, n.name(),
                                               n.type_string());
  for (const Edge* edge : n.out_edges()) {
    // Only data edges carrying a resource handle propagate the source.
    if (edge->IsControlEdge()) continue;
    if (edge->dst()->input_type(edge->dst_input()) == DataType::DT_RESOURCE) {
      (*user_to_source)[edge] = source;
    }
  }
}
// If `n` is an _Arg whose index appears in `resource_arg_indices`, treats it
// as a resource source and records it for every resource-typed out edge.
// Fails when such an _Arg occurs outside of a function body, since the
// analysis can only resolve it through an enclosing call node.
Status PropagateFromArgOp(
    const Node& n, const absl::optional<std::string>& function_name,
    const absl::flat_hash_set<int>& resource_arg_indices,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  TF_RET_CHECK(n.type_string() == kArgOp);

  int arg_index;
  TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", &arg_index));
  // Arguments not carrying a Stack/TensorArray resource are ignored.
  if (!resource_arg_indices.contains(arg_index)) return Status::OK();

  TF_RET_CHECK(function_name.has_value())
      << "ResourceUsageAnalysis does not support analyzing _Arg nodes "
         "carrying Stack/TensorArray resource in given graph unless they "
         "are in function calls.";

  const ResourceUsageAnalysis::NodeInfo source(function_name, n.name(),
                                               n.type_string());
  for (const Edge* edge : n.out_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst()->input_type(edge->dst_input()) == DataType::DT_RESOURCE) {
      (*user_to_source)[edge] = source;
    }
  }
  return Status::OK();
}
// Translates the analysis result of a called function's body back into the
// caller's graph:
//  * A source reported as an _Arg inside the callee is really whatever feeds
//    the corresponding input edge of `call_node`; it is resolved through
//    `user_to_source`.
//  * A user reported as a _Retval inside the callee is really every consumer
//    of the corresponding output edge of `call_node`; those edges are
//    recorded in `user_to_source` for further propagation instead of being
//    emitted as final users.
// All other (source, user) pairs are copied into `caller_source_to_path`.
// NOTE(review): `caller_function_name` is unused in this body — presumably
// kept for interface symmetry; confirm before removing.
Status UpdateResourceUsageFromFunctionBodyAnalysis(
    const Node& call_node,
    const absl::optional<absl::string_view>& caller_function_name,
    const FunctionBody& fbody,
    const absl::flat_hash_map<
        ResourceUsageAnalysis::NodeInfo,
        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>&
        called_function_source_to_path,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        caller_source_to_path) {
  std::unordered_map<std::string, Node*> node_name_index =
      fbody.graph->BuildNodeNameIndex();
  for (const auto& it : called_function_source_to_path) {
    ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
    // If source is an _Arg, then the true source is actually corresponding
    // edge that feeds into function call node with the same index.
    if (src_node_info.op_ == kArgOp) {
      // Use find() rather than operator[]: a missing name would otherwise
      // default-insert a null Node* that is dereferenced just below.
      const auto arg_it = node_name_index.find(src_node_info.node_name_);
      TF_RET_CHECK(arg_it != node_name_index.end());
      const Node* arg_src = arg_it->second;
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
      const Edge* e;
      // TODO(ycao): Allow overriding input_edge to _Arg index mapping. This is
      // needed for cond function of while nodes.
      TF_RETURN_IF_ERROR(call_node.input_edge(index, &e));
      src_node_info = (*user_to_source)[e];
    }
    for (const auto& dst_node_info : it.second) {
      // If user is an _Retval, then the true user is actually corresponding
      // edge of that _Retval.
      if (dst_node_info.op_ == kRetvalOp) {
        const auto ret_it = node_name_index.find(dst_node_info.node_name_);
        TF_RET_CHECK(ret_it != node_name_index.end());
        const Node* ret_user = ret_it->second;
        int index;
        TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
        absl::InlinedVector<const Edge*, 1> outs;
        // TODO(ycao): Allow overriding _Retval index to call node output edge
        // mapping. This is needed for cond function of while nodes.
        TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(call_node, index));
        for (const Edge* o : outs) (*user_to_source)[o] = src_node_info;
      } else {
        (*caller_source_to_path)[src_node_info].emplace(dst_node_info);
      }
    }
  }
  return Status::OK();
}
// Handles a function-call node `n`: recursively analyzes the called function
// body and folds the callee's (source, user) pairs back into the caller's
// bookkeeping. Errors out when the call stack exceeds kMaxCallDepth.
Status PropagateThroughCallOp(
    const Node& n, const absl::optional<std::string>& function_name,
    const int call_depth, FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path) {
  if (call_depth > kMaxCallDepth) {
    return errors::InvalidArgument(
        "Function call stack in given graph is too deep, last function ",
        "name is: ", function_name.value());
  }
  // Collect the indices of the call's input arguments that are fed by an
  // already-known Stack/TensorArray resource edge.
  absl::flat_hash_set<int> resource_arg_indices;
  for (const Edge* in_edge : n.in_edges()) {
    if (user_to_source->contains(in_edge)) {
      resource_arg_indices.emplace(in_edge->dst_input());
    }
  }
  // Instantiate the called function so its body can be walked; the handle is
  // released on every exit path by the cleanup below.
  FunctionLibraryRuntime::Handle handle;
  TF_RETURN_IF_ERROR(InstantiateFunctionCall(n.def(), lib_runtime, &handle));
  auto release_handle_on_return = gtl::MakeCleanup(
      [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
  const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
  // Recursively analyze the function body for its own sources and users.
  absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                      absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
      called_function_source_to_path;
  TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
      fbody->graph, n.type_string(), call_depth + 1, resource_arg_indices,
      lib_runtime, &called_function_source_to_path));
  // Map callee _Arg/_Retval endpoints back onto the caller's edges.
  TF_RETURN_IF_ERROR(UpdateResourceUsageFromFunctionBodyAnalysis(
      n, function_name, *fbody, called_function_source_to_path, user_to_source,
      source_to_path));
  return Status::OK();
}
// Analyzes pass through values for Identity and IdentityN ops.
// Analyzes pass through values for Identity and IdentityN ops.
// An Identity forwards its single input (slot 0) to every output edge; an
// IdentityN forwards input slot i to output slot i. Any output edge whose
// corresponding input edge carries a known resource inherits that resource's
// source in `user_to_source`.
Status PropagateThroughIdentityOp(
    const Node& n,
    absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>*
        user_to_source) {
  TF_RET_CHECK(n.IsIdentity() || n.type_string() == kIdentityNOp);
  if (n.IsIdentity()) {
    // The input edge lookup is loop-invariant for Identity; hoist it out of
    // the per-out-edge loop (previously it ran once per output edge).
    const Edge* in;
    TF_RETURN_IF_ERROR(n.input_edge(0, &in));
    if (!user_to_source->contains(in)) return Status::OK();
    // Copy by value: inserting below may rehash the flat_hash_map and would
    // invalidate a reference into it.
    const ResourceUsageAnalysis::NodeInfo src = (*user_to_source)[in];
    for (const Edge* o : n.out_edges()) {
      if (o->IsControlEdge()) continue;
      user_to_source->emplace(std::make_pair(o, src));
    }
  } else {
    // IdentityN: each output slot mirrors the input slot of the same index.
    for (const Edge* o : n.out_edges()) {
      if (o->IsControlEdge()) continue;
      const Edge* in;
      TF_RETURN_IF_ERROR(n.input_edge(o->src_output(), &in));
      if (!user_to_source->contains(in)) continue;
      user_to_source->emplace(std::make_pair(o, (*user_to_source)[in]));
    }
  }
  return Status::OK();
}
// Core analysis pass. Walks `graph` in reverse post order (ordered stably by
// node name), tracking in `user_to_source` which edges currently carry a
// Stack/TensorArray resource, then emits every (source, user) pair into
// `source_to_path`.
//
// Args:
//   graph: the graph (or instantiated function body) to analyze.
//   function_name: name of the function this graph is the body of; empty for
//     the top-level graph.
//   call_depth: current function-call nesting depth, bounded by
//     kMaxCallDepth inside PropagateThroughCallOp.
//   resource_arg_indices: indices of this graph's _Arg nodes that receive
//     Stack/TensorArray resources from the caller.
//   lib_runtime: used to instantiate called functions for recursion.
//   source_to_path: output map; cleared on entry.
//
// Returns InvalidArgument for control-flow v1 nodes and for (not yet
// supported) functional If/While nodes.
Status AnalyzeResourceUsage(
    const Graph* graph, const absl::optional<std::string>& function_name,
    const int call_depth, const absl::flat_hash_set<int>& resource_arg_indices,
    FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
                        absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
        source_to_path) {
  source_to_path->clear();

  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(*graph, &reverse_post_order, NodeComparatorName{});

  // user_to_source maps from an edge carrying a Stack or TensorArray resource
  // to the node that created this resource.
  absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>
      user_to_source;
  for (const Node* n : reverse_post_order) {
    if (IsControlFlowV1Node(n)) {
      return errors::InvalidArgument(
          "AnalyzeResourceUsage does not support control flow v1 node: ",
          n->DebugString());
    }

    // TODO(ycao): Support pass-through functional while/if nodes.
    if (n->type_string() == kIfOp || n->type_string() == kWhileOp) {
      return errors::InvalidArgument(
          "AnalyzeResourceUsage does not yet support control flow v2 "
          "node: ",
          n->DebugString());
    }

    // Record a resource source edge.
    if (IsStackOrTensorArraySource(*n)) {
      PropagateFromStackOrTensorArraySourceOp(*n, function_name,
                                              &user_to_source);
      continue;
    }

    // Arguments that are listed in resource_arg_indices are also considered as
    // resource sources.
    if (n->IsArg()) {
      TF_RETURN_IF_ERROR(PropagateFromArgOp(
          *n, function_name, resource_arg_indices, &user_to_source));
      continue;
    }

    // Recursively analyze function call ops.
    if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(PropagateThroughCallOp(*n, function_name, call_depth,
                                                lib_runtime, &user_to_source,
                                                source_to_path));
      continue;
    }

    // Identity/IdentityN forward resources without creating new ones.
    if (n->IsIdentity() || n->type_string() == kIdentityNOp) {
      TF_RETURN_IF_ERROR(PropagateThroughIdentityOp(*n, &user_to_source));
    }
  }

  // Every remaining resource-carrying edge names a final user of its source.
  for (const auto& it : user_to_source) {
    (*source_to_path)[it.second].emplace(function_name, it.first->dst()->name(),
                                         it.first->dst()->type_string());
  }
  return Status::OK();
}
} // anonymous namespace
/*Static*/ Status ResourceUsageAnalysis::Analyze(
    const Graph* graph, FunctionLibraryRuntime* lib_runtime,
    absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
        source_to_path) {
  // Public entry point: the top-level graph has no enclosing function name,
  // starts at call depth zero, and receives no resource-carrying arguments.
  const absl::flat_hash_set<int> no_resource_args;
  return AnalyzeResourceUsage(graph, /*function_name=*/{}, /*call_depth=*/0,
                              no_resource_args, lib_runtime, source_to_path);
}
} // namespace tensorflow