Utility function to figure out whether a cluster has reference variables

PiperOrigin-RevId: 270145387
George Karpenkov 2019-09-19 16:01:04 -07:00 committed by TensorFlower Gardener
parent 8c521f81b1
commit 7e06663468
4 changed files with 354 additions and 2 deletions

tensorflow/compiler/jit/BUILD

@@ -632,10 +632,12 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/gtl:cleanup",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -789,12 +791,14 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/cc:ops",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",

tensorflow/compiler/jit/xla_cluster_util.cc

@@ -18,17 +18,18 @@ limitations under the License.
#include <unordered_map>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/xla_config_registry.h"
@@ -386,6 +387,190 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
return result;
}
namespace {
using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;
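// Returns the functions `n` may call: `n`'s own op if it directly names a
// function in the library, otherwise every function-valued attribute of `n`
// (including lists of functions).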
CallTargetListTy GetCallTargetListFromNode(
const Node& n, FunctionLibraryRuntime* lib_runtime) {
const FunctionLibraryDefinition& flib_def =
*lib_runtime->GetFunctionLibraryDefinition();
if (flib_def.Find(n.type_string())) {
NameAttrList callee;
callee.set_name(n.type_string());
*callee.mutable_attr() = n.def().attr();
return {callee};
}
CallTargetListTy result;
for (const auto& name_attr_pair : n.attrs()) {
const AttrValue& attr_value = name_attr_pair.second;
if (attr_value.value_case() == AttrValue::kFunc) {
result.push_back(attr_value.func());
} else if (attr_value.value_case() == AttrValue::kList) {
result.insert(result.end(), attr_value.list().func().begin(),
attr_value.list().func().end());
}
}
return result;
}
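// Direction in which GetNodesRelatedToRefVariablesInDirection propagates
// membership: kForward walks producer->consumer edges, kBackward the reverse.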
enum class Direction { kForward, kBackward };
Status GetNodesRelatedToRefVariablesInDirection(
const Graph& graph, FunctionLibraryRuntime* lib_runtime,
Direction direction, int depth, absl::flat_hash_set<Node*>* result);
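// Returns true if any function in `call_target_list` may (transitively, up
// to kMaxDepth levels of calls) contain a node with a ref-typed input or
// output. Errs on the side of returning true whenever it cannot be sure.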
xla::StatusOr<bool> DoesAnyCalleeHaveRefNodes(
const CallTargetListTy& call_target_list,
FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
const int kMaxDepth = 10;
if (depth == kMaxDepth && !call_target_list.empty()) {
// Conservative answer to avoid recursing too much.
return true;
}
absl::flat_hash_set<Node*> callee_ref_nodes;
for (const NameAttrList& call_target : call_target_list) {
const OpRegistrationData* op_reg;
if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
const OpDef& op = op_reg->op_def;
if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef& arg) {
return arg.is_ref();
})) {
return true;
}
continue;
}
callee_ref_nodes.clear();
FunctionLibraryRuntime::Handle handle;
if (!lib_runtime
->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
&handle)
.ok()) {
VLOG(2) << "Could not find " << call_target.name()
<< " in the function library.";
// Since we don't know the semantics of `n` we don't know if this is an
// error. We return true to signal a conservative answer.
return true;
}
auto release_handle_on_return = gtl::MakeCleanup(
[&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
*fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));
// We could possibly use something cheaper than
// GetNodesRelatedToRefVariablesInDirection since we only care about the
// size of `callee_ref_nodes`, but for now we don't care.
if (!callee_ref_nodes.empty()) {
return true;
}
}
return false;
}
// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
// direction.
Status GetNodesRelatedToRefVariablesInDirection(
const Graph& graph, FunctionLibraryRuntime* lib_runtime,
Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
std::vector<Node*> nodes_in_order;
if (direction == Direction::kForward) {
GetReversePostOrder(graph, &nodes_in_order,
/*stable_comparator=*/NodeComparatorName());
} else {
GetPostOrder(graph, &nodes_in_order,
/*stable_comparator=*/NodeComparatorName());
}
int old_result_size;
int iterations = 0;
const int kMaxIterations = 10 * 1000;
std::vector<bool> callee_has_ref_nodes_cache;
callee_has_ref_nodes_cache.resize(graph.num_node_ids());
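// DoesAnyCalleeHaveRefNodes is purely structural, so its answer for a node
// cannot change across fixpoint iterations; compute it on the first sweep
// (iterations == 1) and serve later sweeps from the cache.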
auto does_callee_have_ref_nodes = [&](Node* n) -> xla::StatusOr<bool> {
if (iterations == 1) {
TF_ASSIGN_OR_RETURN(
bool callee_has_ref_nodes,
DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
lib_runtime, direction, depth));
callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
return callee_has_ref_nodes;
} else {
return {callee_has_ref_nodes_cache[n->id()]};
}
};
do {
TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";
old_result_size = result->size();
for (Node* n : nodes_in_order) {
if (n->IsSource() || n->IsSink()) {
continue;
}
bool inserted_n = false;
const EdgeSet& edges =
direction == Direction::kForward ? n->in_edges() : n->out_edges();
for (const Edge* e : edges) {
if (result->contains(direction == Direction::kForward ? e->src()
: e->dst())) {
result->insert(n);
inserted_n = true;
break;
}
}
if (inserted_n) {
continue;
}
if (direction == Direction::kForward &&
absl::c_any_of(n->output_types(), IsRefType)) {
result->insert(n);
continue;
}
TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
does_callee_have_ref_nodes(n));
if (callee_has_ref_nodes) {
result->insert(n);
continue;
}
}
// Loop until convergence.
} while (result->size() != old_result_size);
VLOG(2) << "# iterations = " << iterations;
return Status::OK();
}
} // namespace
xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
absl::flat_hash_set<Node*> result;
TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
graph, lib_runtime, Direction::kForward, 0, &result));
TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
graph, lib_runtime, Direction::kBackward, 0, &result));
VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
<< " nodes";
return result;
}
// Register a callback for querying XlaGlobalJitLevel.
REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);

tensorflow/compiler/jit/xla_cluster_util.h

@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
@@ -94,6 +95,14 @@ bool IsShapeConsumerOp(const Node& node);
// `XlaAutoClusteringSummary` for details.
XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
// Returns the set of nodes that have a path to or from nodes that may have ref
// variables as input or output.
//
// We assume each node has a trivial path to itself, so the returned set
// includes all of the nodes that have ref variables as input or output.
xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
const Graph& graph, FunctionLibraryRuntime* lib_runtime);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
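For orientation, here is a minimal call-site sketch of the new entry point. The helper name LogRefRelatedNodes and its logging are illustrative, not part of this commit; it assumes a populated Graph and a FunctionLibraryRuntime* (obtained, e.g., the way the test below builds one).

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {

// Hypothetical helper: report every node that must be treated as touching a
// ref variable, mirroring how the test below drives the new API.
Status LogRefRelatedNodes(const Graph& graph,
                          FunctionLibraryRuntime* lib_runtime) {
  TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> ref_related,
                      GetNodesRelatedToRefVariables(graph, lib_runtime));
  for (Node* n : ref_related) {
    VLOG(1) << n->name() << " is related to a ref variable";
  }
  return Status::OK();
}

}  // namespace tensorflow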

tensorflow/compiler/jit/xla_cluster_util_test.cc

@@ -19,8 +19,11 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -29,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -130,5 +134,155 @@ TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) {
EXPECT_FALSE(IsSingleGpuGraph(*root.graph()));
}
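// Runs GetNodesRelatedToRefVariables on the graph built in `scope` and
// returns the resulting node names, sorted so tests can compare them
// deterministically.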
xla::StatusOr<std::vector<string>> GetNodesRelatedToRefVarsSorted(
const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) {
FunctionDefLibrary flib;
FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib);
if (flib_def == nullptr) {
flib_def = &flib_def_local;
}
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
TF_GRAPH_DEF_VERSION, flib_def,
OptimizerOptions{}));
FunctionLibraryRuntime* lib_runtime =
pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
TF_ASSIGN_OR_RETURN(absl::flat_hash_set<Node*> nodes_related_to_ref_vars,
GetNodesRelatedToRefVariables(*graph, lib_runtime));
std::vector<string> names;
absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names),
[](Node* n) { return n->name(); });
absl::c_sort(names);
return names;
}
void CreateSubgraphTouchingRefVar(const Scope& s) {
Output variable =
ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT);
Output read = ops::Identity(s.WithOpName("read_ref_var"), variable);
Output neg = ops::Negate(s.WithOpName("negate_ref"), read);
Output add = ops::Add(s.WithOpName("add_ref"), neg, neg);
Output constant =
ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0));
s.graph()->AddControlEdge(constant.node(), variable.node());
}
void CreateSubgraphNotTouchingRefVar(const Scope& s) {
Output constant =
ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0));
Output neg = ops::Negate(s.WithOpName("negate_normal"), constant);
Output add = ops::Add(s.WithOpName("add_normal"), neg, neg);
}
void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) {
NameAttrList ref_float_function;
ref_float_function.set_name("RefFloatFn");
ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span<Input>{}},
{DT_FLOAT}, ref_float_function);
Output constant =
ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0));
s.graph()->AddControlEdge(call.operation.node(), constant.node());
}
void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) {
NameAttrList regular_float_function;
regular_float_function.set_name("RegularFloatFn");
ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span<Input>{}},
{DT_FLOAT}, regular_float_function);
Output constant =
ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0));
s.graph()->AddControlEdge(call.operation.node(), constant.node());
}
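// Defines RefFloatFn, whose body contains a VariableV2 node (a ref-typed
// output), so nodes calling it must be treated as ref-related.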
void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
FunctionDef make_ref_float = FunctionDefHelper::Define(
"RefFloatFn", {}, {"r:float"}, {},
{{{"var"},
"VariableV2",
{},
{{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}},
{{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}});
*fdef_lib->add_function() = make_ref_float;
}
void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) {
Tensor seven(DT_FLOAT, {});
seven.scalar<float>()() = 7;
FunctionDef make_regular_float = FunctionDefHelper::Define(
"RegularFloatFn", {}, {"r:float"}, {},
{{{"r"}, "Const", {}, {{"dtype", DT_FLOAT}, {"value", seven}}}});
*fdef_lib->add_function() = make_regular_float;
}
TEST(NodesRelatedToRefVariables, Basic) {
Scope root = Scope::NewRootScope().ExitOnError();
FunctionDefLibrary fdef_lib;
CreateSubgraphTouchingRefVar(root);
CreateSubgraphNotTouchingRefVar(root);
AddRefFunctionFunctionDef(&fdef_lib);
CreateSubgraphCallingFunctionWithRefVar(root);
AddRegularFunctionFunctionDef(&fdef_lib);
CreateSubgraphCallingFunctionWithoutRefVar(root);
FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
GetNodesRelatedToRefVarsSorted(root, &flib_def));
std::vector<string> expected({
"RefFloat",
"add_ref",
"constant_ref",
"constant_ref_pco",
"negate_ref",
"read_ref_var",
"variable",
});
EXPECT_EQ(names, expected);
}
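// Builds a minimal control-flow cycle: Enter and Merge consume `init_value`,
// and NextIteration is wired back into Merge's second input, giving the
// traversal a cyclic graph on which it still has to terminate.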
Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) {
s = s.NewSubScope(std::string(loop_name));
ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name);
ops::Merge merge(s.WithOpName("merge"), {init_value, init_value});
Output next_iteration =
ops::NextIteration(s.WithOpName("next_itr"), merge.output);
return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(),
1);
}
TEST(NodesRelatedToRefVariables, Cycles) {
Scope root = Scope::NewRootScope().ExitOnError();
Output variable = ops::Variable(root.WithOpName("variable"),
PartialTensorShape{}, DT_FLOAT);
TF_ASSERT_OK(
MakeLoop(root, ops::Identity(root.WithOpName("read_ref_var"), variable),
"ref_loop"));
TF_ASSERT_OK(MakeLoop(
root, ops::Const(root.WithOpName("constant"), Input::Initializer(0.0)),
"normal_loop"));
TF_ASSERT_OK_AND_ASSIGN(std::vector<string> names,
GetNodesRelatedToRefVarsSorted(root));
std::vector<string> expected({"read_ref_var", "ref_loop/init_value",
"ref_loop/merge", "ref_loop/next_itr",
"variable"});
EXPECT_EQ(names, expected);
}
} // namespace
} // namespace tensorflow