Add function call support in resource op analysis
PiperOrigin-RevId: 257099969
This commit is contained in:
parent
bdfa6ed2d2
commit
2036c4ad02
@ -737,12 +737,16 @@ cc_library(
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
":resource_operation_table",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/hash",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -753,6 +757,7 @@ tf_cuda_cc_test(
|
||||
deps = [
|
||||
":resource_util",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
|
@ -21,10 +21,12 @@ limitations under the License.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/error_codes.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -35,6 +37,10 @@ using stream_executor::port::StatusOr;
|
||||
const char kIdentityNOp[] = "IdentityN";
|
||||
const char kIfOp[] = "If";
|
||||
const char kWhileOp[] = "While";
|
||||
const char kArgOp[] = "_Arg";
|
||||
const char kRetvalOp[] = "_Retval";
|
||||
|
||||
const int kMaxCallDepth = 100;
|
||||
|
||||
bool IsControlFlowV1Node(const Node* n) {
|
||||
return (n->IsEnter() || n->IsExit() || n->IsSwitch() || n->IsMerge() ||
|
||||
@ -58,14 +64,27 @@ StatusOr<const Edge*> WalkBackPassThroughEdge(const Edge* e) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
// TODO(ycao): Support pass-through function calls and functional while/if
|
||||
// nodes.
|
||||
|
||||
// Reaching here means e is not coming from a pass through node, return empty
|
||||
// vector to indicate we can no longer trace back.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// TODO(ycao): Add this as Tensorflow Node method.
|
||||
StatusOr<absl::InlinedVector<const Edge*, 1>> OutputEdgesByIndex(const Node* n,
|
||||
int idx) {
|
||||
absl::InlinedVector<const Edge*, 1> res;
|
||||
if (idx >= n->num_outputs()) {
|
||||
return errors::InvalidArgument("Invalid out_edge index: ", idx, ", Node ",
|
||||
n->name(), " only has ", n->num_outputs(),
|
||||
" outputs.");
|
||||
}
|
||||
|
||||
for (const Edge* o : n->out_edges()) {
|
||||
if (o->src_output() == idx) res.emplace_back(o);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
bool IsStackOrTensorArraySource(const Node* n) {
|
||||
const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n->type_string());
|
||||
|
||||
@ -76,41 +95,143 @@ bool IsStackOrTensorArraySource(const Node* n) {
|
||||
return n->num_outputs() > 0 && n->output_type(0) == DataType::DT_RESOURCE;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Status AnalyzeResourceOpSourcePath(
|
||||
const Graph* graph,
|
||||
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>*
|
||||
sources_paths) {
|
||||
sources_paths->clear();
|
||||
Status AnalyzeResourceUsage(
|
||||
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
|
||||
const absl::optional<std::string>& function_name, const int call_depth,
|
||||
const absl::flat_hash_set<int>& resource_arg_indices,
|
||||
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
|
||||
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>*
|
||||
source_to_path) {
|
||||
source_to_path->clear();
|
||||
|
||||
std::vector<Node*> reverse_post_order;
|
||||
GetReversePostOrder(*graph, &reverse_post_order, NodeComparatorName{});
|
||||
|
||||
// user_to_source maps from an edge carrying a Stack or TensorArray resource
|
||||
// to the node that created this resource.
|
||||
absl::flat_hash_map<const Edge*, const Node*> user_to_source;
|
||||
absl::flat_hash_map<const Edge*, ResourceUsageAnalysis::NodeInfo>
|
||||
user_to_source;
|
||||
for (const Node* n : reverse_post_order) {
|
||||
if (IsControlFlowV1Node(n)) {
|
||||
return errors::InvalidArgument(
|
||||
"AnalyzeResourceOpSourcePath does not support control flow v1 node: ",
|
||||
"AnalyzeResourceUsage does not support control flow v1 node: ",
|
||||
n->DebugString());
|
||||
}
|
||||
|
||||
// TODO(ycao): Support pass-through functional while/if nodes.
|
||||
if (n->type_string() == kIfOp || n->type_string() == kWhileOp) {
|
||||
return errors::InvalidArgument(
|
||||
"AnalyzeResourceOpSourcePath does not yet support control flow v2 "
|
||||
"AnalyzeResourceUsage does not yet support control flow v2 "
|
||||
"node: ",
|
||||
n->DebugString());
|
||||
}
|
||||
|
||||
// Record a resource source edge.
|
||||
if (IsStackOrTensorArraySource(n)) {
|
||||
ResourceUsageAnalysis::NodeInfo src_node_info(function_name, n->name(),
|
||||
n->type_string());
|
||||
for (const Edge* o : n->out_edges()) {
|
||||
if (o->IsControlEdge()) continue;
|
||||
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE)
|
||||
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
|
||||
continue;
|
||||
user_to_source[o] = n;
|
||||
}
|
||||
user_to_source[o] = src_node_info;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Arguments that are listed in resource_arg_indices are also considered as
|
||||
// resource sources.
|
||||
if (n->IsArg()) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
if (!resource_arg_indices.contains(index)) continue;
|
||||
|
||||
TF_RET_CHECK(function_name.has_value())
|
||||
<< "ResourceUsageAnalysis does not support analyzing _Arg nodes "
|
||||
"carrying Stack/TensorArray resource in given graph unless they "
|
||||
"are in function calls.";
|
||||
|
||||
const ResourceUsageAnalysis::NodeInfo src_node_info(
|
||||
function_name, n->name(), n->type_string());
|
||||
|
||||
for (const Edge* o : n->out_edges()) {
|
||||
if (o->IsControlEdge()) continue;
|
||||
if (o->dst()->input_type(o->dst_input()) != DataType::DT_RESOURCE) {
|
||||
continue;
|
||||
}
|
||||
user_to_source[o] = src_node_info;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsFunctionCall(*lib_runtime->GetFunctionLibraryDefinition(), *n)) {
|
||||
if (call_depth > kMaxCallDepth) {
|
||||
return errors::InvalidArgument(
|
||||
"Function call stack in given graph is too deep, last function ",
|
||||
"name is: ", function_name.value());
|
||||
}
|
||||
// resource_arg_indices_for_call contains all indices of the input
|
||||
// arguments that carry Stack/TensorArray resource handles.
|
||||
absl::flat_hash_set<int> resource_arg_indices_for_call;
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (!user_to_source.contains(e)) continue;
|
||||
resource_arg_indices_for_call.emplace(e->dst_input());
|
||||
}
|
||||
|
||||
absl::string_view called_function_name = n->type_string();
|
||||
FunctionLibraryRuntime::Handle handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
InstantiateFunctionCall(n->def(), lib_runtime, &handle));
|
||||
auto release_handle_on_return = gtl::MakeCleanup(
|
||||
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
|
||||
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
|
||||
|
||||
// Recursively analyze called function for resource sources and users.
|
||||
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
|
||||
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
|
||||
called_function_source_to_path;
|
||||
TF_RETURN_IF_ERROR(AnalyzeResourceUsage(
|
||||
fbody->graph, lib_runtime,
|
||||
absl::optional<std::string>(called_function_name), call_depth + 1,
|
||||
resource_arg_indices_for_call, &called_function_source_to_path));
|
||||
|
||||
std::unordered_map<std::string, Node*> node_name_index =
|
||||
fbody->graph->BuildNodeNameIndex();
|
||||
|
||||
for (auto it : called_function_source_to_path) {
|
||||
ResourceUsageAnalysis::NodeInfo src_node_info = it.first;
|
||||
|
||||
// If source is an _Arg, then the true source is actually corresponding
|
||||
// edge that feeds into function call node with the same index.
|
||||
if (src_node_info.op_ == kArgOp) {
|
||||
const Node* arg_src = node_name_index[src_node_info.node_name_];
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(arg_src->attrs(), "index", &index));
|
||||
|
||||
const Edge* e;
|
||||
TF_RETURN_IF_ERROR(n->input_edge(index, &e));
|
||||
const Node* true_src = e->src();
|
||||
src_node_info.function_name_ = function_name;
|
||||
src_node_info.node_name_ = true_src->name();
|
||||
src_node_info.op_ = true_src->type_string();
|
||||
}
|
||||
|
||||
for (const auto& dst_node_info : it.second) {
|
||||
// If user is an _Retval, then the true user is actually corresponding
|
||||
// edge of that _Retval.
|
||||
if (dst_node_info.op_ == kRetvalOp) {
|
||||
const Node* ret_user = node_name_index[dst_node_info.node_name_];
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(ret_user->attrs(), "index", &index));
|
||||
|
||||
absl::InlinedVector<const Edge*, 1> outs;
|
||||
TF_ASSIGN_OR_RETURN(outs, OutputEdgesByIndex(n, index));
|
||||
for (const Edge* o : outs) user_to_source[o] = src_node_info;
|
||||
} else {
|
||||
(*source_to_path)[src_node_info].emplace(dst_node_info);
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -119,15 +240,28 @@ Status AnalyzeResourceOpSourcePath(
|
||||
if (o->IsControlEdge()) continue;
|
||||
TF_ASSIGN_OR_RETURN(const Edge* e, WalkBackPassThroughEdge(o));
|
||||
if (!e || !user_to_source.contains(e)) continue;
|
||||
user_to_source[o] = user_to_source[e];
|
||||
user_to_source.emplace(std::make_pair(o, user_to_source[e]));
|
||||
}
|
||||
}
|
||||
|
||||
for (auto it : user_to_source) {
|
||||
(*sources_paths)[it.second].emplace(it.first->dst());
|
||||
ResourceUsageAnalysis::NodeInfo dst_node_info(
|
||||
function_name, it.first->dst()->name(), it.first->dst()->type_string());
|
||||
(*source_to_path)[it.second].emplace(dst_node_info);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
/*Static*/ Status ResourceUsageAnalysis::Analyze(
|
||||
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
|
||||
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
|
||||
source_to_path) {
|
||||
return AnalyzeResourceUsage(
|
||||
graph, lib_runtime, /*function_name=*/{}, /*call_depth=*/0,
|
||||
/*resource_arg_indices=*/absl::flat_hash_set<int>(), source_to_path);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -20,38 +20,78 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/hash/hash.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// AnalyzeResourceOpSourcePath analyzes a Tensorflow graph and finds all
|
||||
// operations that creates Stack/TensorArray resources and all the operations
|
||||
// that consume resource created by them.
|
||||
//
|
||||
// Note that _Arg nodes that introduce resources are not considered sources.
|
||||
// Note again that Control Flow v1 nodes (Enter/Exit/Switch/Merge/NextIteration)
|
||||
// are not supported. Graphs contain these nodes cause analysis failures.
|
||||
// However Control Flow v2 nodes (While/If) will be supported.
|
||||
//
|
||||
// TODO(b/135628319): Support analyzing function call and functional while/if
|
||||
// as pass-through ops.
|
||||
//
|
||||
// For example, consider following subgraph:
|
||||
//
|
||||
// TensorArrayOp -> Identity -> TensorArrayWriteOp
|
||||
//
|
||||
// It should be able to tell that TensorArrayWriteOp actually operates on the
|
||||
// resource created by TensorArrayOp even though there might be
|
||||
// non-resource-specific operations like Identity (or other pass-through
|
||||
// operations).
|
||||
//
|
||||
// sources_paths maps the nodes that creates resources to all nodes that operate
|
||||
// on corresponding resource, not including sources themselves. It is cleared
|
||||
// upon calling this method.
|
||||
Status AnalyzeResourceOpSourcePath(
|
||||
const Graph* graph,
|
||||
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>*
|
||||
sources_paths);
|
||||
class ResourceUsageAnalysis {
|
||||
public:
|
||||
// NodeInfo is a triple of function_name:node_name:op to uniquely identity a
|
||||
// node in graph. ResourceUsageAnalysis uses it to represent resource sources
|
||||
// and users.
|
||||
class NodeInfo {
|
||||
public:
|
||||
absl::optional<std::string> function_name_;
|
||||
std::string node_name_;
|
||||
std::string op_;
|
||||
|
||||
NodeInfo() {}
|
||||
|
||||
NodeInfo(const absl::optional<std::string>& function_name,
|
||||
std::string node_name, std::string op)
|
||||
: function_name_(function_name),
|
||||
node_name_(std::move(node_name)),
|
||||
op_(std::move(op)) {}
|
||||
|
||||
std::string DebugString() const {
|
||||
return absl::StrJoin({function_name_.value_or(""), node_name_, op_}, ":");
|
||||
}
|
||||
|
||||
bool operator==(const NodeInfo& o) const {
|
||||
return function_name_ == o.function_name_ && node_name_ == o.node_name_ &&
|
||||
op_ == o.op_;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const NodeInfo& o) {
|
||||
return H::combine(std::move(h), o.function_name_, o.node_name_, o.op_);
|
||||
}
|
||||
};
|
||||
|
||||
// This method analyzes a Tensorflow graph and finds all operations that
|
||||
// create Stack/TensorArray resources and all the operations that consume
|
||||
// resource created by them.
|
||||
//
|
||||
// Note that _Arg nodes that introduce resources are not considered sources.
|
||||
// Note again that Control Flow v1 nodes
|
||||
// (Enter/Exit/Switch/Merge/NextIteration) are not supported. Graphs contain
|
||||
// these nodes cause analysis failures. However Control Flow v2 nodes
|
||||
// (While/If) will be supported.
|
||||
//
|
||||
// TODO(b/135628319): Support analyzing functional while/if as pass-through
|
||||
// ops.
|
||||
//
|
||||
// For example, consider following subgraph:
|
||||
//
|
||||
// TensorArrayOp -> Identity -> TensorArrayWriteOp
|
||||
//
|
||||
// It should be able to tell that TensorArrayWriteOp actually operates on the
|
||||
// resource created by TensorArrayOp even though there might be
|
||||
// non-resource-specific operations like Identity (or other pass-through
|
||||
// operations).
|
||||
//
|
||||
// source_to_path maps the nodes that creates resources to all nodes that
|
||||
// operate on the corresponding resource, not including sources themselves. It
|
||||
// is cleared upon calling this method.
|
||||
static Status Analyze(
|
||||
const Graph* graph, FunctionLibraryRuntime* lib_runtime,
|
||||
absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
|
||||
source_to_path);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_
|
||||
|
@ -26,33 +26,57 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
ResourceUsageAnalysis::NodeInfo node_info_from_string(absl::string_view s) {
|
||||
std::vector<std::string> tokens = absl::StrSplit(s, ':');
|
||||
EXPECT_EQ(tokens.size(), 3);
|
||||
|
||||
ResourceUsageAnalysis::NodeInfo node_info;
|
||||
if (tokens[0].empty()) {
|
||||
node_info.function_name_ = absl::nullopt;
|
||||
} else {
|
||||
node_info.function_name_ = std::move(tokens[0]);
|
||||
}
|
||||
node_info.node_name_ = std::move(tokens[1]);
|
||||
node_info.op_ = std::move(tokens[2]);
|
||||
return node_info;
|
||||
}
|
||||
|
||||
void AnalyzeAndVerify(
|
||||
const GraphDef& graphdef,
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>*
|
||||
const GraphDef& graphdef, FunctionLibraryDefinition* flib_def,
|
||||
const absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>&
|
||||
expected) {
|
||||
auto graph = absl::make_unique<Graph>(OpRegistry::Global());
|
||||
auto graph = absl::make_unique<Graph>(flib_def);
|
||||
TF_EXPECT_OK(
|
||||
ConvertGraphDefToGraph(GraphConstructorOptions(), graphdef, graph.get()));
|
||||
|
||||
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>
|
||||
sources_paths;
|
||||
TF_EXPECT_OK(AnalyzeResourceOpSourcePath(graph.get(), &sources_paths));
|
||||
auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
|
||||
nullptr, Env::Default(), TF_GRAPH_DEF_VERSION, flib_def,
|
||||
OptimizerOptions());
|
||||
FunctionLibraryRuntime* lib_runtime =
|
||||
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
|
||||
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
|
||||
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
|
||||
source_to_path;
|
||||
TF_EXPECT_OK(ResourceUsageAnalysis::Analyze(graph.get(), lib_runtime,
|
||||
&source_to_path));
|
||||
|
||||
EXPECT_EQ(sources_paths.size(), expected->size());
|
||||
|
||||
for (const auto it : sources_paths) {
|
||||
const std::string& src_name = it.first->name();
|
||||
const auto& expected_path = expected->at(src_name);
|
||||
EXPECT_EQ(it.second.size(), expected_path.size());
|
||||
for (const Node* n : it.second) {
|
||||
EXPECT_TRUE(expected_path.find(n->name()) != expected_path.end());
|
||||
absl::flat_hash_map<ResourceUsageAnalysis::NodeInfo,
|
||||
absl::flat_hash_set<ResourceUsageAnalysis::NodeInfo>>
|
||||
expected_source_to_path;
|
||||
for (auto it : expected) {
|
||||
auto src_node_info = node_info_from_string(it.first);
|
||||
for (const std::string& user : it.second) {
|
||||
expected_source_to_path[src_node_info].emplace(
|
||||
node_info_from_string(user));
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(source_to_path, expected_source_to_path);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
@ -87,8 +111,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserNoPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op"] = absl::flat_hash_set<std::string>({"stack_close"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op:StackV2"] =
|
||||
absl::flat_hash_set<std::string>({":stack_close:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserWithPassThrough) {
|
||||
@ -126,9 +151,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceSingleUserWithPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op"] =
|
||||
absl::flat_hash_set<std::string>({"resource_identity", "stack_close"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":resource_identity:Identity", ":stack_close:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserNoPassThrough) {
|
||||
@ -169,9 +194,9 @@ TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserNoPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op"] =
|
||||
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserWithPassThrough) {
|
||||
@ -217,9 +242,10 @@ TEST(ResourceOpAnalyzerTest, SingleResourceMultipleUserWithPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op"] = absl::flat_hash_set<std::string>(
|
||||
{"resource_identity", "stack_close0", "stack_close1"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":resource_identity:Identity", ":stack_close0:StackCloseV2",
|
||||
":stack_close1:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserNoPassThrough) {
|
||||
@ -278,11 +304,11 @@ TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserNoPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op0"] =
|
||||
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
|
||||
expected["stack_op1"] =
|
||||
absl::flat_hash_set<std::string>({"stack_close2", "stack_close3"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op0:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
|
||||
expected[":stack_op1:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close2:StackCloseV2", ":stack_close3:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserWithPassThrough) {
|
||||
@ -341,11 +367,168 @@ TEST(ResourceOpAnalyzerTest, MultipleResourceMultipleUserWithPassThrough) {
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["stack_op0"] =
|
||||
absl::flat_hash_set<std::string>({"stack_close0", "stack_close1"});
|
||||
expected["stack_op1"] =
|
||||
absl::flat_hash_set<std::string>({"stack_close2", "stack_close3"});
|
||||
AnalyzeAndVerify(graphdef, &expected);
|
||||
expected[":stack_op0:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close0:StackCloseV2", ":stack_close1:StackCloseV2"});
|
||||
expected[":stack_op1:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close2:StackCloseV2", ":stack_close3:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, ResourcePassThroughFunction) {
|
||||
auto library = absl::make_unique<FunctionDefLibrary>();
|
||||
/*
|
||||
* pass_through_function:
|
||||
*
|
||||
* _Arg -> Identity -> _Retval
|
||||
*/
|
||||
*library->add_function() = FunctionDefHelper::Define(
|
||||
/*function_name=*/"pass_through_function",
|
||||
/*arg_def=*/{"in: resource"},
|
||||
/*ret_def=*/{"out: resource"},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/
|
||||
{{{"out"}, "Identity", {"in"}, {{"T", DataType::DT_RESOURCE}}}});
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
auto opts = builder.opts();
|
||||
auto op_reg = opts.op_registry();
|
||||
|
||||
{
|
||||
/*
|
||||
* stack_size -> stack_op -> pass_through_function -> stack_close
|
||||
*/
|
||||
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
|
||||
op_reg);
|
||||
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
|
||||
Node* stack_size_placeholder =
|
||||
opts.FinalizeBuilder(&stack_size_placeholder_builder);
|
||||
|
||||
NodeBuilder stack_op_builder("stack_op", "StackV2", op_reg);
|
||||
stack_op_builder.Input(stack_size_placeholder).Attr("elem_type", DT_FLOAT);
|
||||
Node* stack_op = opts.FinalizeBuilder(&stack_op_builder);
|
||||
|
||||
NodeBuilder pass_through_fn_builder("pass_through_fn",
|
||||
"pass_through_function", op_reg);
|
||||
pass_through_fn_builder.Input(stack_op);
|
||||
Node* pass_through_fn = opts.FinalizeBuilder(&pass_through_fn_builder);
|
||||
|
||||
NodeBuilder stack_close_builder("stack_close", "StackCloseV2", op_reg);
|
||||
stack_close_builder.Input(pass_through_fn);
|
||||
opts.FinalizeBuilder(&stack_close_builder);
|
||||
}
|
||||
|
||||
GraphDef graphdef;
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":stack_close:StackCloseV2", ":pass_through_fn:pass_through_function",
|
||||
"pass_through_function:out:Identity"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, ResourceUserInFunction) {
|
||||
auto library = absl::make_unique<FunctionDefLibrary>();
|
||||
/*
|
||||
* resource_user_function:
|
||||
*
|
||||
* _Arg -> Identity -> StackCloseV2
|
||||
*/
|
||||
*library->add_function() = FunctionDefHelper::Define(
|
||||
/*function_name=*/"resource_user_function",
|
||||
/*arg_def=*/{"in: resource"},
|
||||
/*ret_def=*/{},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/
|
||||
{{{"stack_close"},
|
||||
"StackCloseV2",
|
||||
{"in"},
|
||||
{{"T", DataType::DT_RESOURCE}}}});
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
auto opts = builder.opts();
|
||||
auto op_reg = opts.op_registry();
|
||||
|
||||
{
|
||||
/*
|
||||
* stack_size -> stack_op -> resource_user_function
|
||||
*/
|
||||
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
|
||||
op_reg);
|
||||
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
|
||||
Node* stack_size_placeholder =
|
||||
opts.FinalizeBuilder(&stack_size_placeholder_builder);
|
||||
|
||||
NodeBuilder stack_op_builder("stack_op", "StackV2", op_reg);
|
||||
stack_op_builder.Input(stack_size_placeholder).Attr("elem_type", DT_FLOAT);
|
||||
Node* stack_op = opts.FinalizeBuilder(&stack_op_builder);
|
||||
|
||||
NodeBuilder resource_user_fn_builder("resource_user_function",
|
||||
"resource_user_function", op_reg);
|
||||
resource_user_fn_builder.Input(stack_op);
|
||||
opts.FinalizeBuilder(&resource_user_fn_builder);
|
||||
}
|
||||
|
||||
GraphDef graphdef;
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected[":stack_op:StackV2"] = absl::flat_hash_set<std::string>(
|
||||
{":resource_user_function:resource_user_function",
|
||||
"resource_user_function:stack_close:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
TEST(ResourceOpAnalyzerTest, ResourceSourceInFunction) {
|
||||
auto library = absl::make_unique<FunctionDefLibrary>();
|
||||
/*
|
||||
* resource_source_function:
|
||||
*
|
||||
* _Arg -> StackV2 -> _Retval
|
||||
*/
|
||||
*library->add_function() = FunctionDefHelper::Define(
|
||||
/*function_name=*/"resource_source_function",
|
||||
/*arg_def=*/{"in: int32"},
|
||||
/*ret_def=*/{"out: resource"},
|
||||
/*attr_def=*/{},
|
||||
/*node_def=*/
|
||||
{{{"out"}, "StackV2", {"in"}, {{"elem_type", DataType::DT_FLOAT}}}});
|
||||
|
||||
FunctionLibraryDefinition flib_def(OpRegistry::Global(), *library);
|
||||
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
|
||||
auto opts = builder.opts();
|
||||
auto op_reg = opts.op_registry();
|
||||
|
||||
{
|
||||
/*
|
||||
* stack_size -> resource_source_function -> stack_close
|
||||
*/
|
||||
NodeBuilder stack_size_placeholder_builder("stack_size", "Placeholder",
|
||||
op_reg);
|
||||
stack_size_placeholder_builder.Attr("dtype", DT_INT32);
|
||||
Node* stack_size_placeholder =
|
||||
opts.FinalizeBuilder(&stack_size_placeholder_builder);
|
||||
|
||||
NodeBuilder resource_source_fn_builder("resource_source_function",
|
||||
"resource_source_function", op_reg);
|
||||
resource_source_fn_builder.Input(stack_size_placeholder);
|
||||
Node* resource_source_function =
|
||||
opts.FinalizeBuilder(&resource_source_fn_builder);
|
||||
|
||||
NodeBuilder stack_close_builder("stack_close", "StackCloseV2", op_reg);
|
||||
stack_close_builder.Input(resource_source_function);
|
||||
opts.FinalizeBuilder(&stack_close_builder);
|
||||
}
|
||||
|
||||
GraphDef graphdef;
|
||||
TF_EXPECT_OK(builder.ToGraphDef(&graphdef));
|
||||
|
||||
absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> expected;
|
||||
expected["resource_source_function:out:StackV2"] =
|
||||
absl::flat_hash_set<std::string>({":stack_close:StackCloseV2"});
|
||||
AnalyzeAndVerify(graphdef, &flib_def, expected);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user