Add function call support in resource op analysis

PiperOrigin-RevId: 257099969
This commit is contained in:
Yanan Cao 2019-07-08 18:35:25 -07:00 committed by TensorFlower Gardener
parent bdfa6ed2d2
commit 2036c4ad02
4 changed files with 442 additions and 80 deletions

View File

@ -737,12 +737,16 @@ cc_library(
visibility = [":friends"],
deps = [
":resource_operation_table",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/strings",
],
)
@ -753,6 +757,7 @@ tf_cuda_cc_test(
deps = [
":resource_util",
"//tensorflow/cc:scope",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",

View File

@ -21,10 +21,12 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
@ -35,6 +37,10 @@ using stream_executor::port::StatusOr;
const char kIdentityNOp[] = "IdentityN";
const char kIfOp[] = "If";
const char kWhileOp[] = "While";
const char kArgOp[] = "_Arg";
const char kRetvalOp[] = "_Retval";
const int kMaxCallDepth = 100;
bool IsControlFlowV1Node(const Node* n) {
return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
@ -58,14 +64,27 @@ StatusOr<const Edge*> WalkBackPassThroughEdge(const Edge* e) {
return ret;
}
// TODO(ycao): Support pass-through function calls and functional while/if
// nodes.
// Reaching here means e is not coming from a pass through node, return empty
// vector to indicate we can no longer trace back.
return nullptr;
}
// TODO(ycao): Add this as Tensorflow Node method.
StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node* n,
int idx) {
absl::InlinedVector<const Edge*, 1> res;
if (idx >= n->num_outputs()) {
return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
n->name(), " only has ", n->num_outputs(),
" outputs.");
}
for (const Edge* o : n->out_edges()) {
if (o->src_output() == idx) res.emplace_back(o);
}
return res;
}
bool IsStackOrTensorArraySource(const Node* n) {
const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
@ -76,41 +95,143 @@ bool IsStackOrTensorArraySource(const Node* n) {
return n->num_outputs() > 0 && n->output_type(0) == DataType::DT_RESOURCE;
}
} // anonymous namespace
Status AnalyzeResourceOpSourcePath(
const Graph* graph,
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>*
sources_paths) {
sources_paths->clear();
Status AnalyzeResourceUsage(
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
const absl::optional<std::string>& function_name, const int call_depth,
const absl::flat_hash_set<int>& resource_arg_indices,
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
source_to_path) {
source_to_path->clear();
std::vector<Node*> reverse_post_order;
GetReversePostOrder(*graph, &reverse_post_order, NodeComparatorName{});
// user_to_source maps from an edge carrying a Stack or TensorArray resource
// to the node that created this resource.
absl::flat_hash_map<const Edge*, const Node*> user_to_source;
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>
user_to_source;
for (const Node* n : reverse_post_order) {
if (IsControlFlowV1Node(n)) {
return errors::InvalidArgument(
"AnalyzeResourceOpSourcePath does not support control flow v1 node: ",
"AnalyzeResourceUsage does not support control flow v1 node: ",
n->DebugString());
}
// TODO(ycao): Support pass-through functional while/if nodes.
if (n->type_string() == kIfOp || n->type_string() == kWhileOp) {
return errors::InvalidArgument(
"AnalyzeResourceOpSourcePath does not yet support control flow v2 "
"AnalyzeResourceUsage does not yet support control flow v2 "
"node: ",
n->DebugString());
}
// Record a resource source edge.
if (IsStackOrTensorArraySource(n)) {
ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n->name(),
n->type_string());
for (const Edge* o : n->out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE)
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
user_to_source[o] = n;
}
user_to_source[o] = src_node_info;
}
continue;
}
// Arguments that are listed in resource_arg_indices are also considered as
// resource sources.
if (n->IsArg()) {
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
if (!resource_arg_indices.contains(index)) continue;
TF_RET_CHECK(function_name.has_value())
<< "ResourceUsageAnalysis does not support analyzing _Arg nodes "
"carrying Stack/TensorArray resource in given graph unless they "
"are in function calls.";
const ResourceUsageAnalysis::NodeInfo src_node_info(
function_name, n->name(), n->type_string());
for (const Edge* o : n->out_edges()) {
if (o->IsControlEdge()) continue;
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
continue;
}
user_to_source[o] = src_node_info;
}
continue;
}
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
if (call_depth > kMaxCallDepth) {
return errors::InvalidArgument(
"Function call stack in given graph is too deep, last function ",
"name is: ", function_name.value());
}
// resource_arg_indices_for_call contains all indices of the input
// arguments that carry Stack/TensorArray resource handles.
absl::flat_hash_set<int> resource_arg_indices_for_call;
for (const Edge* e : n->in_edges()) {
if (!user_to_source.contains(e)) continue;
resource_arg_indices_for_call.emplace(e->dst_input());
}
absl::string_view called_function_name = n->type_string();
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(
InstantiateFunctionCall(n->def(), lib_runtime, &handle));
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
// Recursively analyze called function for resource sources and users.
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
called_function_source_to_path;
TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
fbody->graph, lib_runtime,
absl::optional<std::string>(called_function_name), call_depth + 1,
resource_arg_indices_for_call, &called_function_source_to_path));
std::unordered_map<std::string, Node*> node_name_index =
fbody->graph->BuildNodeNameIndex();
for (auto it : called_function_source_to_path) {
ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
// If source is an _Arg, then the true source is actually corresponding
// edge that feeds into function call node with the same index.
if (src_node_info.op_ == kArgOp) {
const Node* arg_src = node_name_index[src_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
const Edge* e;
TF_RETURN_IF_ERROR(n->input_edge(index, &e));
const Node* true_src = e->src();
src_node_info.function_name_ = function_name;
src_node_info.node_name_ = true_src->name();
src_node_info.op_ = true_src->type_string();
}
for (const auto& dst_node_info : it.second) {
// If user is an _Retval, then the true user is actually corresponding
// edge of that _Retval.
if (dst_node_info.op_ == kRetvalOp) {
const Node* ret_user = node_name_index[dst_node_info.node_name_];
int index;
TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
absl::InlinedVector<const Edge*, 1> outs;
TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
for (const Edge* o : outs) user_to_source[o] = src_node_info;
} else {
(*source_to_path)[src_node_info].emplace(dst_node_info);
}
}
}
continue;
}
@ -119,15 +240,28 @@ Status AnalyzeResourceOpSourcePath(
if (o->IsControlEdge()) continue;
TF_ASSIGN_OR_RETURN(const Edge* e, WalkBackPassThroughEdge(o));
if (!e || !user_to_source.contains(e)) continue;
user_to_source[o] = user_to_source[e];
user_to_source.emplace(std::make_pair(o, user_to_source[e]));
}
}
for (auto it : user_to_source) {
(*sources_paths)[it.second].emplace(it.first->dst());
ResourceUsageAnalysis::NodeInfo dst_node_info(
function_name, it.first->dst()->name(), it.first->dst()->type_string());
(*source_to_path)[it.second].emplace(dst_node_info);
}
return Status::OK();
}
} // anonymous namespace
/*Static*/ Status ResourceUsageAnalysis::Analyze(
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
source_to_path) {
return AnalyzeResourceUsage(
graph, lib_runtime, /*function_name=*/{}, /*call_depth=*/0,
/*resource_arg_indices=*/absl::flat_hash_set<int>(), source_to_path);
}
} // namespace tensorflow

View File

@ -20,38 +20,78 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace tensorflow {
// AnalyzeResourceOpSourcePath analyzes a Tensorflow graph and finds all
// operations that creates Stack/TensorArray resources and all the operations
// that consume resource created by them.
//
// Note that _Arg nodes that introduce resources are not considered sources.
// Note again that Control Flow v1 nodes (Enter/Exit/Switch/Merge/NextIteration)
// are not supported. Graphs contain these nodes cause analysis failures.
// However Control Flow v2 nodes (While/If) will be supported.
//
// TODO(b/135628319): Support analyzing function call and functional while/if
// as pass-through ops.
//
// For example, consider following subgraph:
//
// TensorArrayOp -> Identity -> TensorArrayWriteOp
//
// It should be able to tell that TensorArrayWriteOp actually operates on the
// resource created by TensorArrayOp even though there might be
// non-resource-specific operations like Identity (or other pass-through
// operations).
//
// sources_paths maps the nodes that creates resources to all nodes that operate
// on corresponding resource, not including sources themselves. It is cleared
// upon calling this method.
Status AnalyzeResourceOpSourcePath(
const Graph* graph,
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>*
sources_paths);
class ResourceUsageAnalysis {
public:
// NodeInfo is a triple of function_name:node_name:op to uniquely identity a
// node in graph. ResourceUsageAnalysis uses it to represent resource sources
// and users.
class NodeInfo {
public:
absl::optional<std::string> function_name_;
std::string node_name_;
std::string op_;
NodeInfo() {}
NodeInfo(const absl::optional<std::string>& function_name,
std::string node_name, std::string op)
: function_name_(function_name),
node_name_(std::move(node_name)),
op_(std::move(op)) {}
std::string DebugString() const {
return absl::StrJoin({function_name_.value_or(""), node_name_, op_}, ":");
}
bool operator==(const NodeInfo& o) const {
return function_name_ == o.function_name_ && node_name_ == o.node_name_ &&
op_ == o.op_;
}
template <typename H>
friend H AbslHashValue(H h, const NodeInfo& o) {
return H::combine(std::move(h), o.function_name_, o.node_name_, o.op_);
}
};
// This method analyzes a Tensorflow graph and finds all operations that
// create Stack/TensorArray resources and all the operations that consume
// resource created by them.
//
// Note that _Arg nodes that introduce resources are not considered sources.
// Note again that Control Flow v1 nodes
// (Enter/Exit/Switch/Merge/NextIteration) are not supported. Graphs contain
// these nodes cause analysis failures. However Control Flow v2 nodes
// (While/If) will be supported.
//
// TODO(b/135628319): Support analyzing functional while/if as pass-through
// ops.
//
// For example, consider following subgraph:
//
// TensorArrayOp -> Identity -> TensorArrayWriteOp
//
// It should be able to tell that TensorArrayWriteOp actually operates on the
// resource created by TensorArrayOp even though there might be
// non-resource-specific operations like Identity (or other pass-through
// operations).
//
// source_to_path maps the nodes that creates resources to all nodes that
// operate on the corresponding resource, not including sources themselves. It
// is cleared upon calling this method.
static Status Analyze(
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
source_to_path);
};
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_

View File

@ -26,33 +26,57 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
ResourceUsageAnalysis::NodeInfo node_info_from_string(absl::string_view s) {
std::vector<std::string> tokens = absl::StrSplit(s, ':');
EXPECT_EQ(tokens.size(), 3);
ResourceUsageAnalysis::NodeInfo node_info;
if (tokens[0].empty()) {
node_info.function_name_ = absl::nullopt;
} else {
node_info.function_name_ = std::move(tokens[0]);
}
node_info.node_name_ = std::move(tokens[1]);
node_info.op_ = std::move(tokens[2]);
return node_info;
}
void AnalyzeAndVerify(
const GraphDef& graphdef,
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>*
const GraphDef& graphdef, FunctionLibraryDefinition* flib_def,
const absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>&
expected) {
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
auto graph = absl::make_unique<Graph>(flib_def);
TF_EXPECT_OK(
ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get()));
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>
sources_paths;
TF_EXPECT_OK(AnalyzeResourceOpSourcePath(graph.get(), &sources_paths));
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def,
OptimizerOptions());
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
source_to_path;
TF_EXPECT_OK(ResourceUsageAnalysis::Analyze(graph.get(), lib_runtime,
&source_to_path));
EXPECT_EQ(sources_paths.size(), expected->size());
for (const auto it : sources_paths) {
const std::string& src_name = it.first->name();
const auto& expected_path = expected->at(src_name);
EXPECT_EQ(it.second.size(), expected_path.size());
for (const Node* n : it.second) {
EXPECT_TRUE(expected_path.find(n->name()) != expected_path.end());
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
expected_source_to_path;
for (auto it : expected) {
auto src_node_info = node_info_from_string(it.first);
for (const std::string& user : it.second) {
expected_source_to_path[src_node_info].emplace(
node_info_from_string(user));
}
}
EXPECT_EQ(source_to_path, expected_source_to_path);
}
} // anonymous namespace
@ -87,8 +111,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserNoPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op"] = absl::flat_hash_set<std::string>({"stack_close"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op:StackV2"] =
absl::flat_hash_set<std::string>({":stack_close:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserWithPassThrough) {
@ -126,9 +151,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserWithPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op"] =
absl::flat_hash_set<std::string>({"resource_identity", "stack_close"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
{":resource_identity:Identity", ":stack_close:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserNoPassThrough) {
@ -169,9 +194,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserNoPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op"] =
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserWithPassThrough) {
@ -217,9 +242,10 @@ TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserWithPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op"] = absl::flat_hash_set<std::string>(
{"resource_identity", "stack_close0", "stack_close1"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
{":resource_identity:Identity", ":stack_close0:StackCloseV2",
":stack_close1:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserNoPassThrough) {
@ -278,11 +304,11 @@ TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserNoPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op0"] =
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
expected["stack_op1"] =
absl::flat_hash_set<std::string>({"stack_close2", "stack_close3"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op0:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
expected[":stack_op1:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close2:StackCloseV2", ":stack_close3:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserWithPassThrough) {
@ -341,11 +367,168 @@ TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserWithPassThrough) {
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["stack_op0"] =
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
expected["stack_op1"] =
absl::flat_hash_set<std::string>({"stack_close2", "stack_close3"});
AnalyzeAndVerify(graphdef, &expected);
expected[":stack_op0:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
expected[":stack_op1:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close2:StackCloseV2", ":stack_close3:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, ResourcePassThroughFunction) {
auto library = absl::make_unique<FunctionDefLibrary>();
/*
* pass_through_function:
*
* _Arg -> Identity -> _Retval
*/
*library->add_function() = FunctionDefHelper::Define(
/*function_name=*/"pass_through_function",
/*arg_def=*/{"in: resource"},
/*ret_def=*/{"out: resource"},
/*attr_def=*/{},
/*node_def=*/
{{{"out"}, "Identity", {"in"}, {{"T", DataType::DT_RESOURCE}}}});
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
auto opts = builder.opts();
auto op_reg = opts.op_registry();
{
/*
* stack_size -> stack_op -> pass_through_function -> stack_close
*/
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
op_reg);
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
Node* stack_size_placeholder =
opts.FinalizeBuilder(&stack_size_placeholder_builder);
NodeBuilder stack_op_builder("stack_op", "StackV2", op_reg);
stack_op_builder.Input(stack_size_placeholder).Attr("elem_type", DT_FLOAT);
Node* stack_op = opts.FinalizeBuilder(&stack_op_builder);
NodeBuilder pass_through_fn_builder("pass_through_fn",
"pass_through_function", op_reg);
pass_through_fn_builder.Input(stack_op);
Node* pass_through_fn = opts.FinalizeBuilder(&pass_through_fn_builder);
NodeBuilder stack_close_builder("stack_close", "StackCloseV2", op_reg);
stack_close_builder.Input(pass_through_fn);
opts.FinalizeBuilder(&stack_close_builder);
}
GraphDef graphdef;
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
{":stack_close:StackCloseV2", ":pass_through_fn:pass_through_function",
"pass_through_function:out:Identity"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, ResourceUserInFunction) {
auto library = absl::make_unique<FunctionDefLibrary>();
/*
* resource_user_function:
*
* _Arg -> Identity -> StackCloseV2
*/
*library->add_function() = FunctionDefHelper::Define(
/*function_name=*/"resource_user_function",
/*arg_def=*/{"in: resource"},
/*ret_def=*/{},
/*attr_def=*/{},
/*node_def=*/
{{{"stack_close"},
"StackCloseV2",
{"in"},
{{"T", DataType::DT_RESOURCE}}}});
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
auto opts = builder.opts();
auto op_reg = opts.op_registry();
{
/*
* stack_size -> stack_op -> resource_user_function
*/
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
op_reg);
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
Node* stack_size_placeholder =
opts.FinalizeBuilder(&stack_size_placeholder_builder);
NodeBuilder stack_op_builder("stack_op", "StackV2", op_reg);
stack_op_builder.Input(stack_size_placeholder).Attr("elem_type", DT_FLOAT);
Node* stack_op = opts.FinalizeBuilder(&stack_op_builder);
NodeBuilder resource_user_fn_builder("resource_user_function",
"resource_user_function", op_reg);
resource_user_fn_builder.Input(stack_op);
opts.FinalizeBuilder(&resource_user_fn_builder);
}
GraphDef graphdef;
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
{":resource_user_function:resource_user_function",
"resource_user_function:stack_close:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
TEST(ResourceOpAnalyzerTest, ResourceSourceInFunction) {
auto library = absl::make_unique<FunctionDefLibrary>();
/*
* resource_source_function:
*
* _Arg -> StackV2 -> _Retval
*/
*library->add_function() = FunctionDefHelper::Define(
/*function_name=*/"resource_source_function",
/*arg_def=*/{"in: int32"},
/*ret_def=*/{"out: resource"},
/*attr_def=*/{},
/*node_def=*/
{{{"out"}, "StackV2", {"in"}, {{"elem_type", DataType::DT_FLOAT}}}});
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
auto opts = builder.opts();
auto op_reg = opts.op_registry();
{
/*
* stack_size -> resource_source_function -> stack_close
*/
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
op_reg);
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
Node* stack_size_placeholder =
opts.FinalizeBuilder(&stack_size_placeholder_builder);
NodeBuilder resource_source_fn_builder("resource_source_function",
"resource_source_function", op_reg);
resource_source_fn_builder.Input(stack_size_placeholder);
Node* resource_source_function =
opts.FinalizeBuilder(&resource_source_fn_builder);
NodeBuilder stack_close_builder("stack_close", "StackCloseV2", op_reg);
stack_close_builder.Input(resource_source_function);
opts.FinalizeBuilder(&stack_close_builder);
}
GraphDef graphdef;
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
expected["resource_source_function:out:StackV2"] =
absl::flat_hash_set<std::string>({":stack_close:StackCloseV2"});
AnalyzeAndVerify(graphdef, &flib_def, expected);
}
} // namespace tensorflow