From 7e066634687493d2fb3226e63c13d2855438d679 Mon Sep 17 00:00:00 2001
From: George Karpenkov
Date: Thu, 19 Sep 2019 16:01:04 -0700
Subject: [PATCH] Utility function to figure out whether a cluster has
 reference variables

PiperOrigin-RevId: 270145387
---
 tensorflow/compiler/jit/BUILD               |   4 +
 tensorflow/compiler/jit/xla_cluster_util.cc | 189 +++++++++++++++++-
 tensorflow/compiler/jit/xla_cluster_util.h  |   9 +
 .../compiler/jit/xla_cluster_util_test.cc   | 154 ++++++++++++++
 4 files changed, 354 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index ecde0c3b270..a47a5c8b972 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -632,10 +632,12 @@ cc_library(
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:graph",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/gtl:cleanup",
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
     ],
@@ -789,12 +791,14 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/cc:function_ops",
+        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index 3863bcf3131..b8b11d2c7cd 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -18,17 +18,18 @@ limitations under the License.
 #include
 
 #include "absl/algorithm/container.h"
-#include "absl/container/flat_hash_map.h"
-#include "absl/container/flat_hash_set.h"
+#include "absl/container/inlined_vector.h"
 #include "absl/strings/match.h"
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/public/session_options.h"
 #include "tensorflow/core/util/device_name_utils.h"
 #include "tensorflow/core/util/xla_config_registry.h"
@@ -386,6 +387,190 @@ XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph) {
   return result;
 }
 
+namespace {
+using CallTargetListTy = absl::InlinedVector<NameAttrList, 2>;
+
+CallTargetListTy GetCallTargetListFromNode(
+    const Node& n, FunctionLibraryRuntime* lib_runtime) {
+  const FunctionLibraryDefinition& flib_def =
+      *lib_runtime->GetFunctionLibraryDefinition();
+  if (flib_def.Find(n.type_string())) {
+    NameAttrList callee;
+    callee.set_name(n.type_string());
+    *callee.mutable_attr() = n.def().attr();
+    return {callee};
+  }
+
+  CallTargetListTy result;
+  for (const auto& name_attr_pair : n.attrs()) {
+    const AttrValue& attr_value = name_attr_pair.second;
+    if (attr_value.value_case() == AttrValue::kFunc) {
+      result.push_back(attr_value.func());
+    } else if (attr_value.value_case() == AttrValue::kList) {
+      result.insert(result.end(), attr_value.list().func().begin(),
+                    attr_value.list().func().end());
+    }
+  }
+
+  return result;
+}
+
+enum class Direction { kForward, kBackward };
+
+Status GetNodesRelatedToRefVariablesInDirection(
+    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
+    Direction direction, int depth, absl::flat_hash_set<Node*>* result);
+
+xla::StatusOr<bool> DoesAnyCalleeHaveRefNodes(
+    const CallTargetListTy& call_target_list,
+    FunctionLibraryRuntime* lib_runtime, Direction direction, int depth) {
+  const int kMaxDepth = 10;
+
+  if (depth == kMaxDepth && !call_target_list.empty()) {
+    // Conservative answer to avoid recursing too much.
+    return true;
+  }
+
+  absl::flat_hash_set<Node*> callee_ref_nodes;
+  for (const NameAttrList& call_target : call_target_list) {
+    const OpRegistrationData* op_reg;
+    if (OpRegistry::Global()->LookUp(call_target.name(), &op_reg).ok()) {
+      const OpDef& op = op_reg->op_def;
+      if (absl::c_any_of(op.output_arg(), [](const OpDef::ArgDef arg) {
+            return arg.is_ref();
+          })) {
+        return true;
+      }
+      continue;
+    }
+
+    callee_ref_nodes.clear();
+    FunctionLibraryRuntime::Handle handle;
+    if (!lib_runtime
+             ->Instantiate(call_target.name(), AttrSlice(&call_target.attr()),
+                           &handle)
+             .ok()) {
+      VLOG(2) << "Could not find " << call_target.name()
+              << " in the function library.";
+      // Since we don't know the semantics of `n` we don't know if this is an
+      // error.  We return true to signal a conservative answer.
+      return true;
+    }
+
+    auto release_handle_on_return = gtl::MakeCleanup(
+        [&] { TF_CHECK_OK(lib_runtime->ReleaseHandle(handle)); });
+
+    const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
+    TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
+        *fbody->graph, lib_runtime, direction, depth + 1, &callee_ref_nodes));
+
+    // We could possibly use something cheaper than
+    // GetNodesRelatedToRefVariablesInDirection since we only care about the
+    // size of `callee_ref_nodes`, but for now we don't care.
+    if (!callee_ref_nodes.empty()) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+// Helper for GetNodesRelatedToRefVariables that traverses the graph in one
+// direction.
+Status GetNodesRelatedToRefVariablesInDirection(
+    const Graph& graph, FunctionLibraryRuntime* lib_runtime,
+    Direction direction, int depth, absl::flat_hash_set<Node*>* result) {
+  std::vector<Node*> nodes_in_order;
+  if (direction == Direction::kForward) {
+    GetReversePostOrder(graph, &nodes_in_order,
+                        /*stable_comparator=*/NodeComparatorName());
+  } else {
+    GetPostOrder(graph, &nodes_in_order,
+                 /*stable_comparator=*/NodeComparatorName());
+  }
+
+  int old_result_size;
+  int iterations = 0;
+
+  const int kMaxIterations = 10 * 1000;
+
+  std::vector<bool> callee_has_ref_nodes_cache;
+  callee_has_ref_nodes_cache.resize(graph.num_node_ids());
+
+  auto does_callee_have_ref_nodes = [&](Node* n) -> xla::StatusOr<bool> {
+    if (iterations == 1) {
+      TF_ASSIGN_OR_RETURN(
+          bool callee_has_ref_nodes,
+          DoesAnyCalleeHaveRefNodes(GetCallTargetListFromNode(*n, lib_runtime),
+                                    lib_runtime, direction, depth));
+      callee_has_ref_nodes_cache[n->id()] = callee_has_ref_nodes;
+      return callee_has_ref_nodes;
+    } else {
+      return {callee_has_ref_nodes_cache[n->id()]};
+    }
+  };
+
+  do {
+    TF_RET_CHECK(iterations++ < kMaxIterations) << "infinite loop?";
+
+    old_result_size = result->size();
+    for (Node* n : nodes_in_order) {
+      if (n->IsSource() || n->IsSink()) {
+        continue;
+      }
+
+      bool inserted_n = false;
+      const EdgeSet& edges =
+          direction == Direction::kForward ? n->in_edges() : n->out_edges();
+      for (const Edge* e : edges) {
+        if (result->contains(direction == Direction::kForward ? e->src()
+                                                              : e->dst())) {
+          result->insert(n);
+          inserted_n = true;
+          break;
+        }
+      }
+
+      if (inserted_n) {
+        continue;
+      }
+
+      if (direction == Direction::kForward &&
+          absl::c_any_of(n->output_types(), IsRefType)) {
+        result->insert(n);
+        continue;
+      }
+
+      TF_ASSIGN_OR_RETURN(bool callee_has_ref_nodes,
+                          does_callee_have_ref_nodes(n));
+      if (callee_has_ref_nodes) {
+        result->insert(n);
+        continue;
+      }
+    }
+
+    // Loop until convergence.
+  } while (result->size() != old_result_size);
+
+  VLOG(2) << "# iterations = " << iterations;
+
+  return Status::OK();
+}
+}  // namespace
+
+xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
+    const Graph& graph, FunctionLibraryRuntime* lib_runtime) {
+  absl::flat_hash_set<Node*> result;
+  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
+      graph, lib_runtime, Direction::kForward, 0, &result));
+  TF_RETURN_IF_ERROR(GetNodesRelatedToRefVariablesInDirection(
+      graph, lib_runtime, Direction::kBackward, 0, &result));
+
+  VLOG(1) << "GetNodesRelatedToRefVariables() found " << result.size()
+          << " nodes";
+  return result;
+}
+
 // Register a callback for querying XlaGlobalJitLevel.
 REGISTER_XLA_CONFIG_GETTER(GetXlaGlobalJitLevel);
 
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index 97fe80258a1..e2a1d159336 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -19,6 +19,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/types/optional.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
 #include "tensorflow/compiler/jit/xla_activity.pb.h"
@@ -94,6 +95,14 @@ bool IsShapeConsumerOp(const Node& node);
 // `XlaAutoClusteringSummary` for details.
 XlaAutoClusteringSummary GetXlaAutoClusteringSummary(const Graph& graph);
 
+// Returns the set of nodes that have a path to or from nodes that may have ref
+// variables as input or output.
+//
+// We assume each node has a trivial path to itself, so the returned set
+// includes all of the nodes that have ref variables as input or output.
+xla::StatusOr<absl::flat_hash_set<Node*>> GetNodesRelatedToRefVariables(
+    const Graph& graph, FunctionLibraryRuntime* lib_runtime);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc
index 571d247c39b..aa87a958c1b 100644
--- a/tensorflow/compiler/jit/xla_cluster_util_test.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc
@@ -19,8 +19,11 @@ limitations under the License.
 #include "absl/strings/str_join.h"
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/graph/algorithm.h"
@@ -29,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -130,5 +134,155 @@ TEST(IsSingleGpuGraph, ReturnsFalseForMultiGpuGraph) { EXPECT_FALSE(IsSingleGpuGraph(*root.graph())); } + +xla::StatusOr> GetNodesRelatedToRefVarsSorted( + const Scope& scope, FunctionLibraryDefinition* flib_def = nullptr) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def_local(OpRegistry::Global(), flib); + if (flib_def == nullptr) { + flib_def = &flib_def_local; + } + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + + std::unique_ptr pflr( + new ProcessFunctionLibraryRuntime(nullptr, Env::Default(), + TF_GRAPH_DEF_VERSION, flib_def, + OptimizerOptions{})); + FunctionLibraryRuntime* lib_runtime = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + TF_ASSIGN_OR_RETURN(absl::flat_hash_set nodes_related_to_ref_vars, + GetNodesRelatedToRefVariables(*graph, lib_runtime)); + + std::vector names; + absl::c_transform(nodes_related_to_ref_vars, std::back_inserter(names), + [](Node* n) { return n->name(); }); + absl::c_sort(names); + return names; +} + +void CreateSubgraphTouchingRefVar(const Scope& s) { + Output variable = + ops::Variable(s.WithOpName("variable"), PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(s.WithOpName("read_ref_var"), variable); + Output neg = ops::Negate(s.WithOpName("negate_ref"), read); + Output add = ops::Add(s.WithOpName("add_ref"), neg, neg); + + Output constant = + ops::Const(s.WithOpName("constant_ref"), Input::Initializer(0.0)); + s.graph()->AddControlEdge(constant.node(), variable.node()); +} + +void CreateSubgraphNotTouchingRefVar(const Scope& s) { + Output constant = + ops::Const(s.WithOpName("constant_normal"), Input::Initializer(0.0)); + Output neg = ops::Negate(s.WithOpName("negate_normal"), constant); + Output add = ops::Add(s.WithOpName("add_normal"), neg, neg); +} + +void CreateSubgraphCallingFunctionWithRefVar(const Scope& s) { + NameAttrList ref_float_function; + ref_float_function.set_name("RefFloatFn"); + ops::PartitionedCall call(s.WithOpName("RefFloat"), {absl::Span{}}, + {DT_FLOAT}, ref_float_function); + Output constant = + ops::Const(s.WithOpName("constant_ref_pco"), Input::Initializer(0.0)); + s.graph()->AddControlEdge(call.operation.node(), constant.node()); +} + +void CreateSubgraphCallingFunctionWithoutRefVar(const Scope& s) { + NameAttrList regular_float_function; + regular_float_function.set_name("RegularFloatFn"); + ops::PartitionedCall call(s.WithOpName("RegularFloat"), {absl::Span{}}, + {DT_FLOAT}, regular_float_function); + Output constant = + ops::Const(s.WithOpName("constant_normal_pco"), Input::Initializer(0.0)); + s.graph()->AddControlEdge(call.operation.node(), constant.node()); +} + +void AddRefFunctionFunctionDef(FunctionDefLibrary* fdef_lib) { + FunctionDef make_ref_float = FunctionDefHelper::Define( + "RefFloatFn", {}, {"r:float"}, {}, + {{{"var"}, + "VariableV2", + {}, + {{"dtype", DT_FLOAT}, {"shape", TensorShape({})}}}, + {{"r"}, "Identity", {"var"}, {{"T", DT_FLOAT}}}}); + *fdef_lib->add_function() = make_ref_float; +} + +void AddRegularFunctionFunctionDef(FunctionDefLibrary* fdef_lib) { + Tensor seven(DT_FLOAT, {}); + seven.scalar()() = 7; + FunctionDef make_regular_float = FunctionDefHelper::Define( + "RegularFloatFn", {}, {"r:float"}, {}, + {{{"r"}, "Const", {}, {{"dtype", 
DT_FLOAT}, {"value", seven}}}}); + *fdef_lib->add_function() = make_regular_float; +} + +TEST(NodesRelatedToRefVariables, Basic) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary fdef_lib; + + CreateSubgraphTouchingRefVar(root); + CreateSubgraphNotTouchingRefVar(root); + + AddRefFunctionFunctionDef(&fdef_lib); + CreateSubgraphCallingFunctionWithRefVar(root); + + AddRegularFunctionFunctionDef(&fdef_lib); + CreateSubgraphCallingFunctionWithoutRefVar(root); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib); + + TF_ASSERT_OK_AND_ASSIGN(std::vector names, + GetNodesRelatedToRefVarsSorted(root, &flib_def)); + + std::vector expected({ + "RefFloat", + "add_ref", + "constant_ref", + "constant_ref_pco", + "negate_ref", + "read_ref_var", + "variable", + }); + + EXPECT_EQ(names, expected); +} + +Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) { + s = s.NewSubScope(std::string(loop_name)); + ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name); + ops::Merge merge(s.WithOpName("merge"), {init_value, init_value}); + Output next_iteration = + ops::NextIteration(s.WithOpName("next_itr"), merge.output); + return s.graph()->UpdateEdge(next_iteration.node(), 0, merge.output.node(), + 1); +} + +TEST(NodesRelatedToRefVariables, Cycles) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + TF_ASSERT_OK( + MakeLoop(root, ops::Identity(root.WithOpName("read_ref_var"), variable), + "ref_loop")); + TF_ASSERT_OK(MakeLoop( + root, ops::Const(root.WithOpName("constant"), Input::Initializer(0.0)), + "normal_loop")); + + TF_ASSERT_OK_AND_ASSIGN(std::vector names, + GetNodesRelatedToRefVarsSorted(root)); + std::vector expected({"read_ref_var", "ref_loop/init_value", + "ref_loop/merge", "ref_loop/next_itr", + "variable"}); + + EXPECT_EQ(names, expected); +} } // namespace } // namespace tensorflow