Utility function to figure out whether a cluster has reference variables
PiperOrigin-RevId: 270145387
This commit is contained in:
parent 8c521f81b1
commit 7e06663468
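The new helper, GetNodesRelatedToRefVariables(graph, lib_runtime), declared in xla_cluster_util.h below, returns the set of nodes with a path to or from nodes that may have ref variables as input or output. A minimal calling sketch, mirroring the setup in the new test; the `graph` and `flib_def` used here are assumed to be built elsewhere:

    // Sketch only: error handling and graph/flib_def construction omitted.
    std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
        new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
                                          TF_GRAPH_DEF_VERSION, flib_def,
                                          OptimizerOptions{}));
    FunctionLibraryRuntime* lib_runtime =
        pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

    TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
                        GetNodesRelatedToRefVariables(*graph, lib_runtime));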
@@ -632,10 +632,12 @@ cc_library(
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/lib/gtl:cleanup",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:optional",
    ],
@@ -789,12 +791,14 @@ tf_cc_test(
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
        "//tensorflow/cc:ops",
        "//tensorflow/compiler/jit/kernels:xla_ops",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
@@ -18,17 +18,18 @@ limitations under the License.
#include <unordered_map>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"
@@ -386,6 +387,190 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
  return result;
}

namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;

CallTargetListTy GetCallTargetListFromNode(
    const Node& n, FunctionLibraryRuntime* lib_runtime) {
  const FunctionLibraryDefinition& flib_def =
      *lib_runtime->GetFunctionLibraryDefinition();
  if (flib_def.Find(n.type_string())) {
    NameAttrList callee;
    callee.set_name(n.type_string());
    *callee.mutable_attr() = n.def().attr();
    return {callee};
  }

  CallTargetListTy result;
  for (const auto& name_attr_pair : n.attrs()) {
    const AttrValue& attr_value = name_attr_pair.second;
    if (attr_value.value_case() == AttrValue::kFunc) {
      result.push_back(attr_value.func());
    } else if (attr_value.value_case() == AttrValue::kList) {
      result.insert(result.end(), attr_value.list().func().begin(),
                    attr_value.list().func().end());
    }
  }

  return result;
}

enum class Direction { kForward, kBackward };

Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result);

xla::StatusOr<bool> DoesAnyCalleeHaveRefNodes(
    const CallTargetListTy& call_target_list,
    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
  const int kMaxDepth = 10;

  if (depth == kMaxDepth && !call_target_list.empty()) {
    // Conservative answer to avoid recursing too much.
    return true;
  }

  absl::flat_hash_set<Node*> callee_ref_nodes;
  for (const NameAttrList& call_target : call_target_list) {
    const OpRegistrationData* op_reg;
    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
      const OpDef& op = op_reg->op_def;
      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef arg) {
            return arg.is_ref();
          })) {
        return true;
      }
      continue;
    }

    callee_ref_nodes.clear();
    FunctionLibraryRuntime::Handle handle;
    if (!lib_runtime
             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
                           &handle)
             .ok()) {
      VLOG(2) << "Could not find " << call_target.name()
              << " in the function library.";
      // Since we don't know the semantics of `n` we don't know if this is an
      // error. We return true to signal a conservative answer.
      return true;
    }

    auto release_handle_on_return = gtl::MakeCleanup(
        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });

    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));

    // We could possibly use something cheaper than
    // GetNodesRelatedToRefVariablesInDirection since we only care about the
    // size of `callee_ref_nodes` but for now we don't care.
    if (!callee_ref_nodes.empty()) {
      return true;
    }
  }

  return false;
}

// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
  std::vector<Node*> nodes_in_order;
  if (direction == Direction::kForward) {
    GetReversePostOrder(graph, &nodes_in_order,
                        /*stable_comparator=*/NodeComparatorName());
  } else {
    GetPostOrder(graph, &nodes_in_order,
                 /*stable_comparator=*/NodeComparatorName());
  }

  int old_result_size;
  int iterations = 0;

  const int kMaxIterations = 10 * 1000;

  std::vector<bool> callee_has_ref_nodes_cache;
  callee_has_ref_nodes_cache.resize(graph.num_node_ids());

  auto does_callee_have_ref_nodes = [&](Node* n) -> xla::StatusOr<bool> {
    if (iterations == 1) {
      TF_ASSIGN_OR_RETURN(
          bool callee_has_ref_nodes,
          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
                                    lib_runtime, direction, depth));
      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
      return callee_has_ref_nodes;
    } else {
      return {callee_has_ref_nodes_cache[n->id()]};
    }
  };

  do {
    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";

    old_result_size = result->size();
    for (Node* n : nodes_in_order) {
      if (n->IsSource() || n->IsSink()) {
        continue;
      }

      bool inserted_n = false;
      const EdgeSet& edges =
          direction == Direction::kForward ? n->in_edges() : n->out_edges();
      for (const Edge* e : edges) {
        if (result->contains(direction == Direction::kForward ? e->src()
                                                              : e->dst())) {
          result->insert(n);
          inserted_n = true;
          break;
        }
      }

      if (inserted_n) {
        continue;
      }

      if (direction == Direction::kForward &&
          absl::c_any_of(n->output_types(), IsRefType)) {
        result->insert(n);
        continue;
      }

      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
                          does_callee_have_ref_nodes(n));
      if (callee_has_ref_nodes) {
        result->insert(n);
        continue;
      }
    }

    // Loop until convergence.
  } while (result->size() != old_result_size);

  VLOG(2) << "# iterations = " << iterations;

  return Status::OK();
}
}  // namespace

xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
  absl::flat_hash_set<Node*> result;
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kForward, 0, &result));
  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
      graph, lib_runtime, Direction::kBackward, 0, &result));

  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
          << " nodes";
  return result;
}

// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
@@ -94,6 +95,14 @@ bool IsShapeConsumerOp(const Node& node);
// `XlaAutoClusteringSummary` for details.
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);

// Returns the set of nodes that have a path to or from nodes that may have ref
// variables as input or output.
//
// We assume each node has a trivial path to itself so the returned set includes
// all of the nodes that have ref variables as input or output.
xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
    const Graph& graph, FunctionLibraryRuntime* lib_runtime);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
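For intuition about the header comment above: the implementation in xla_cluster_util.cc seeds a result set and then repeatedly adds any node adjacent to the set, once walking edges forward and once backward, until the set stops changing. A standalone toy sketch of that fixed-point closure (illustrative only; the node ids, edges, and seed set below are made up, and the real code seeds the set from ref-typed outputs and function callees):

    // Toy illustration of the closure computation; not TensorFlow code.
    #include <cstddef>
    #include <cstdio>
    #include <set>
    #include <utility>
    #include <vector>

    using Edges = std::vector<std::pair<int, int>>;  // (src, dst) pairs

    // Grows `result` with every node that has an edge from (forward) or to
    // (backward) a node already in `result`, looping until a fixed point.
    void Closure(const Edges& edges, bool forward, std::set<int>* result) {
      std::size_t old_size;
      do {
        old_size = result->size();
        for (const auto& e : edges) {
          if (forward && result->count(e.first)) result->insert(e.second);
          if (!forward && result->count(e.second)) result->insert(e.first);
        }
      } while (result->size() != old_size);  // Loop until convergence.
    }

    int main() {
      // Toy graph: 0 -> 1 -> 2 and 3 -> 4; node 0 stands in for a ref variable.
      Edges edges = {{0, 1}, {1, 2}, {3, 4}};
      std::set<int> related = {0};
      Closure(edges, /*forward=*/true, &related);   // nodes reachable from the set
      Closure(edges, /*forward=*/false, &related);  // nodes that can reach the set
      for (int n : related) std::printf("related: %d\n", n);  // prints 0, 1, 2
      return 0;
    }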
@@ -19,8 +19,11 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -29,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {
@@ -130,5 +134,155 @@ TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) {

  EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}

xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
    const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) {
  FunctionDefLibrary flib;
  FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib);
  if (flib_def == nullptr) {
    flib_def = &flib_def_local;
  }

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));

  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
                                        TF_GRAPH_DEF_VERSION, flib_def,
                                        OptimizerOptions{}));
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);

  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
                      GetNodesRelatedToRefVariables(*graph, lib_runtime));

  std::vector<string> names;
  absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names),
                    [](Node* n) { return n->name(); });
  absl::c_sort(names);
  return names;
}

void CreateSubgraphTouchingRefVar(const Scope& s) {
  Output variable =
      ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
  Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
  Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
  Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);

  Output constant =
      ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
  s.graph()->AddControlEdge(constant.node(), variable.node());
}

void CreateSubgraphNotTouchingRefVar(const Scope& s) {
  Output constant =
      ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
  Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
  Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
}

void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) {
  NameAttrList ref_float_function;
  ref_float_function.set_name("RefFloatFn");
  ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span<Input>{}},
                            {DT_FLOAT}, ref_float_function);
  Output constant =
      ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0));
  s.graph()->AddControlEdge(call.operation.node(), constant.node());
}

void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) {
  NameAttrList regular_float_function;
  regular_float_function.set_name("RegularFloatFn");
  ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span<Input>{}},
                            {DT_FLOAT}, regular_float_function);
  Output constant =
      ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0));
  s.graph()->AddControlEdge(call.operation.node(), constant.node());
}

void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
  FunctionDef make_ref_float = FunctionDefHelper::Define(
      "RefFloatFn", {}, {"r:float"}, {},
      {{{"var"},
        "VariableV2",
        {},
        {{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}},
       {{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}});
  *fdef_lib->add_function() = make_ref_float;
}

void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
  Tensor seven(DT_FLOAT, {});
  seven.scalar<float>()() = 7;
  FunctionDef make_regular_float = FunctionDefHelper::Define(
      "RegularFloatFn", {}, {"r:float"}, {},
      {{{"r"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", seven}}}});
  *fdef_lib->add_function() = make_regular_float;
}

TEST(NodesRelatedToRefVariables, Basic) {
  Scope root = Scope::NewRootScope().ExitOnError();

  FunctionDefLibrary fdef_lib;

  CreateSubgraphTouchingRefVar(root);
  CreateSubgraphNotTouchingRefVar(root);

  AddRefFunctionFunctionDef(&fdef_lib);
  CreateSubgraphCallingFunctionWithRefVar(root);

  AddRegularFunctionFunctionDef(&fdef_lib);
  CreateSubgraphCallingFunctionWithoutRefVar(root);

  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
                          GetNodesRelatedToRefVarsSorted(root, &flib_def));

  std::vector<string> expected({
      "RefFloat",
      "add_ref",
      "constant_ref",
      "constant_ref_pco",
      "negate_ref",
      "read_ref_var",
      "variable",
  });

  EXPECT_EQ(names, expected);
}

Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) {
  s = s.NewSubScope(std::string(loop_name));
  ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name);
  ops::Merge merge(s.WithOpName("merge"), {init_value, init_value});
  Output next_iteration =
      ops::NextIteration(s.WithOpName("next_itr"), merge.output);
  return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(),
                               1);
}

TEST(NodesRelatedToRefVariables, Cycles) {
  Scope root = Scope::NewRootScope().ExitOnError();
  Output variable = ops::Variable(root.WithOpName("variable"),
                                  PartialTensorShape{}, DT_FLOAT);
  TF_ASSERT_OK(
      MakeLoop(root, ops::Identity(root.WithOpName("read_ref_var"), variable),
               "ref_loop"));
  TF_ASSERT_OK(MakeLoop(
      root, ops::Const(root.WithOpName("constant"), Input::Initializer(0.0)),
      "normal_loop"));

  TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
                          GetNodesRelatedToRefVarsSorted(root));
  std::vector<string> expected({"read_ref_var", "ref_loop/init_value",
                                "ref_loop/merge", "ref_loop/next_itr",
                                "variable"});

  EXPECT_EQ(names, expected);
}
}  // namespace
}  // namespace tensorflow