From fbbb83b995e3785dfd991248b0a4271aaf2d595c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 12 Mar 2020 10:09:48 -0700
Subject: [PATCH] Const analysis should peek into PartitionedCall and
 StatefulPartitionedCall.

PiperOrigin-RevId: 300571637
Change-Id: I4aef56c80e8bd2f14152f49aaf69778c2d916315
---
 .../compiler/jit/xla_compile_on_demand_op.cc  |  3 +-
 tensorflow/compiler/tf2xla/BUILD              |  3 -
 tensorflow/compiler/tf2xla/const_analysis.cc  | 71 +++++--------------
 tensorflow/compiler/tf2xla/const_analysis.h   | 17 ++---
 .../compiler/tf2xla/const_analysis_test.cc    | 56 ---------------
 .../python/eager/def_function_xla_jit_test.py | 18 -----
 6 files changed, 21 insertions(+), 147 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 1f8f2cecfbc..45ce68ba9c0 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -99,8 +99,7 @@ Status XlaCompileOnDemandOp::MustArgumentBeConstant(
   // TODO(jmolloy): This could be expensive, so memoize.
   std::vector<int> constant_input_indices;
   TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
-      op_kernel, &constant_input_indices, flib_runtime,
-      /*cached_arg_indices=*/nullptr));
+      op_kernel, &constant_input_indices, flib_runtime));
   *result = absl::c_binary_search(constant_input_indices, argument_idx);
   return Status::OK();
 }
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index cd47337e401..a6f88df7e40 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -373,7 +373,6 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:stream_executor_no_cuda",
         "@com_google_absl//absl/algorithm:container",
-        "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:optional",
@@ -611,12 +610,10 @@ tf_cc_test(
         ":xla_compiler",
        "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
-        "//tensorflow/cc:functional_ops",
         "//tensorflow/cc:ops",
         "//tensorflow/compiler/jit:xla_cluster_util",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
         "//tensorflow/core:ops",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 5b420fd4f45..48513a43fb3 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -85,10 +85,10 @@ Status CondConstInputIndices(
   return Status::OK();
 }
 
-Status GetCompileTimeConstInputs(
-    const NodeDef& node, const OpKernel* op_kernel, const OpDef* op_def,
-    std::vector<int>* const_input_idxs, FunctionLibraryRuntime* flib_runtime,
-    GraphConstArgIndicesCache* cached_arg_indices) {
+Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel,
+                                 const OpDef* op_def,
+                                 std::vector<int>* const_input_idxs,
+                                 FunctionLibraryRuntime* flib_runtime) {
   DCHECK(op_def != nullptr || op_kernel != nullptr);
   // TODO(b/124403063): Implement similar functionality for function call nodes.
   if (node.op() == "While" || node.op() == "StatelessWhile") {
@@ -106,12 +106,10 @@ Status GetCompileTimeConstInputs(
     std::vector<bool> compile_time_const_arg_indices(num_inputs);
     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
         *(fcond->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime,
-        [](const Edge&) { return true; }, cached_arg_indices));
+        /*compile_time_const_nodes=*/nullptr, flib_runtime));
     TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
         *(fbody->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime,
-        [](const Edge&) { return true; }, cached_arg_indices));
+        /*compile_time_const_nodes=*/nullptr, flib_runtime));
     for (int i = 0; i < num_inputs; i++) {
       if (compile_time_const_arg_indices[i]) {
         // Check that this input is actually a loop invariant.
@@ -147,22 +145,6 @@ Status GetCompileTimeConstInputs(
     TF_RETURN_IF_ERROR(
         GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies));
     return CondConstInputIndices(branch_bodies, const_input_idxs, flib_runtime);
-  } else if (node.op() == "PartitionedCall" ||
-             node.op() == "StatefulPartitionedCall") {
-    const FunctionBody* fbody;
-    TF_RETURN_IF_ERROR(GetFunctionBody(flib_runtime, node, "f", &fbody));
-    int num_inputs = fbody->fdef.signature().input_arg_size();
-    std::vector<bool> compile_time_const_arg_indices(num_inputs);
-    TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
-        *(fbody->graph), &compile_time_const_arg_indices,
-        /*compile_time_const_nodes=*/nullptr, flib_runtime,
-        [](const Edge&) { return true; }, cached_arg_indices));
-    for (int i = 0; i < num_inputs; i++) {
-      if (compile_time_const_arg_indices[i]) {
-        const_input_idxs->push_back(i);
-      }
-    }
-    return Status::OK();
   } else if (op_def != nullptr) {
     return XlaOpRegistry::CompileTimeConstantInputs(node, *op_def,
                                                     const_input_idxs);
@@ -172,13 +154,12 @@ Status GetCompileTimeConstInputs(
   }
 }
 
-Status GetCompileTimeConstInputs(
-    const Node* node, std::vector<int>* const_input_idxs,
-    FunctionLibraryRuntime* flib_runtime,
-    GraphConstArgIndicesCache* cached_arg_indices) {
+Status GetCompileTimeConstInputs(const Node* node,
+                                 std::vector<int>* const_input_idxs,
+                                 FunctionLibraryRuntime* flib_runtime) {
   return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr,
                                    &node->op_def(), const_input_idxs,
-                                   flib_runtime, cached_arg_indices);
+                                   flib_runtime);
 }
 
 }  // namespace
@@ -189,17 +170,7 @@ Status BackwardsConstAnalysis(const Graph& g,
                               std::vector<bool>* compile_time_const_arg_indices,
                               std::vector<bool>* compile_time_const_nodes,
                               FunctionLibraryRuntime* flib_runtime,
-                              std::function<bool(const Edge&)> edge_filter,
-                              GraphConstArgIndicesCache* cached_arg_indices) {
-  // Avoid exponential runtime by explicit memoization: can do this only
-  // for the nested calls which don't have `compile_time_const_nodes` set.
-  if (!compile_time_const_nodes && cached_arg_indices &&
-      cached_arg_indices->contains(&g)) {
-    VLOG(3) << "Memoized constant arg indices for the graph: " << &g;
-    *compile_time_const_arg_indices = cached_arg_indices->at(&g);
-    return Status::OK();
-  }
-
+                              std::function<bool(const Edge&)> edge_filter) {
   std::vector<bool> compile_time_const_nodes_impl;
   if (compile_time_const_nodes) {
     CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -208,11 +179,6 @@ Status BackwardsConstAnalysis(const Graph& g,
     compile_time_const_nodes = &compile_time_const_nodes_impl;
   }
 
-  GraphConstArgIndicesCache cached_arg_indices_impl;
-  if (!cached_arg_indices) {
-    cached_arg_indices = &cached_arg_indices_impl;
-  }
-
   Status status;
   auto visit = [&](Node* node) {
     if (!status.ok()) return;
@@ -255,8 +221,7 @@ Status BackwardsConstAnalysis(const Graph& g,
 
     // Mark any compile-time constant operator arguments as const.
     std::vector<int> const_input_idxs;
-    status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime,
-                                       cached_arg_indices);
+    status = GetCompileTimeConstInputs(node, &const_input_idxs, flib_runtime);
 
     if (!status.ok()) {
       return;
@@ -287,19 +252,15 @@ Status BackwardsConstAnalysis(const Graph& g,
   // acyclic graph.
   DFS(g, /*enter=*/{}, /*leave=*/visit, NodeComparatorName{},
      [](const Edge& edge) { return !edge.src()->IsNextIteration(); });
-  if (cached_arg_indices && compile_time_const_arg_indices) {
-    cached_arg_indices->emplace(&g, *compile_time_const_arg_indices);
-  }
   return status;
 }
 
-Status GetCompileTimeConstInputs(
-    const OpKernel* op_kernel, std::vector<int>* const_input_idxs,
-    FunctionLibraryRuntime* flib_runtime,
-    GraphConstArgIndicesCache* cached_arg_indices) {
+Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
+                                 std::vector<int>* const_input_idxs,
+                                 FunctionLibraryRuntime* flib_runtime) {
   return GetCompileTimeConstInputs(op_kernel->def(), op_kernel,
                                    /*op_def=*/nullptr, const_input_idxs,
-                                   flib_runtime, cached_arg_indices);
+                                   flib_runtime);
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
index 0db6d44c6a7..587347ff8a5 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.h
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -18,15 +18,11 @@ limitations under the License.
 
 #include <vector>
 
-#include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/lib/core/status.h"
 
 namespace tensorflow {
 
-using GraphConstArgIndicesCache =
-    absl::flat_hash_map<const Graph*, std::vector<bool>>;
-
 // Backwards dataflow analysis that finds nodes in a graph that must be
 // compile-time constants for us to be able to lower the graph to XLA.
 //
@@ -38,24 +34,19 @@ using GraphConstArgIndicesCache =
 // `compile_time_const_nodes`, if `compile_time_const_nodes` is not null.
 //
 // Only propagate const-ness along edges for which `edge_filter` returns true.
-//
-// `cached_arg_indices` is a memoization cache used for nested invocations on
-// function calls, which caches what argument indices need to be constant for
-// each associated graph (e.g. called function).
 Status BackwardsConstAnalysis(
     const Graph& g, std::vector<bool>* compile_time_const_arg_indices,
     std::vector<bool>* compile_time_const_nodes,
     FunctionLibraryRuntime* flib_runtime,
-    std::function<bool(const Edge&)> edge_filter =
-        [](const Edge& e) { return true; },
-    GraphConstArgIndicesCache* cached_arg_indices = nullptr);
+    std::function<bool(const Edge&)> edge_filter = [](const Edge& e) {
+      return true;
+    });
 
 // Given an op kernel and function library runtime, return all the indices of
 // inputs that need to be compile time constant.
 Status GetCompileTimeConstInputs(const OpKernel* op_kernel,
                                  std::vector<int>* const_input_idxs,
-                                 FunctionLibraryRuntime* flib_runtime,
-                                 GraphConstArgIndicesCache* cached_arg_indices);
+                                 FunctionLibraryRuntime* flib_runtime);
 
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
index 936b74f7b33..ed5f004550f 100644
--- a/tensorflow/compiler/tf2xla/const_analysis_test.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -19,14 +19,11 @@ limitations under the License.
 
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
-#include "tensorflow/cc/ops/functional_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/jit/xla_cluster_util.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 namespace {
@@ -92,59 +89,6 @@ TEST(ConstAnalysisTest, TopologicalOrder) {
   }
 }
 
-void TestFunctionCall(bool is_stateful_partitioned_call) {
-  FunctionDef callee = FunctionDefHelper::Define(
-      "Callee", {"t:float", "shape:int32"}, {"result:float"}, {},
-      {{{"result"}, "Reshape", {"t", "shape"}, {{"T", DT_FLOAT}}}});
-
-  FunctionDefLibrary flib;
-  *flib.add_function() = callee;
-  FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
-
-  Scope root = Scope::NewRootScope().ExitOnError();
-
-  auto arg0 = ops::_Arg(root.WithOpName("tensor"), DT_FLOAT, 0);
-  auto arg1 = ops::_Arg(root.WithOpName("shape"), DT_INT32, 1);
-
-  NameAttrList call_attrs;
-  call_attrs.set_name("Callee");
-  if (is_stateful_partitioned_call) {
-    ops::StatefulPartitionedCall b(root.WithOpName("Call"),
-                                   {Output(arg0), Output(arg1)}, {DT_FLOAT},
-                                   call_attrs);
-  } else {
-    ops::PartitionedCall b(root.WithOpName("Call"),
-                           {Output(arg0), Output(arg1)}, {DT_FLOAT},
-                           call_attrs);
-  }
-
-  Graph graph(&flib_def);
-  TF_ASSERT_OK(root.ToGraph(&graph));
-
-  OptimizerOptions opts;
-  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
-      new ProcessFunctionLibraryRuntime(nullptr, Env::Default(),
-                                        /*config=*/nullptr,
-                                        TF_GRAPH_DEF_VERSION, &flib_def, opts));
-  FunctionLibraryRuntime* lib_runtime =
-      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
-
-  std::vector<bool> const_args(2, false);
-  TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args,
-                                      /*compile_time_const_nodes=*/nullptr,
-                                      lib_runtime));
-
-  EXPECT_EQ(const_args, std::vector<bool>({false, true}));
-}
-
-TEST(ConstAnalysisTest, PartitionedCall) {
-  TestFunctionCall(/*is_stateful_partitioned_call=*/false);
-}
-
-TEST(ConstAnalysisTest, StatefulPartitionedCall) {
-  TestFunctionCall(/*is_stateful_partitioned_call=*/true);
-}
-
 TEST(ConstAnalysisTest, DontFollowControlDependencies) {
   Scope root = Scope::NewRootScope();
 
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index adff0858488..16d57ef36da 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -232,24 +232,6 @@ class DefFunctionTest(test.TestCase):
       with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                    'not compilable'):
         c.f1(inputs)
-  def testMustBeConstantPropagation(self):
-    if test.is_built_with_rocm():
-      return
-
-    @def_function.function(experimental_compile=True)
-    def f():
-      return constant_op.constant([0, 2, 1], dtype=dtypes.int32)
-
-    @def_function.function(experimental_compile=True)
-    def g(a, b):
-      return array_ops.transpose(a, b)
-
-    @def_function.function
-    def z():
-      return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())
-
-    z()
-
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
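
Note for reviewers (not part of the diff): the deleted testMustBeConstantPropagation above exercised the behavior this change touches. A minimal standalone sketch of the same pattern, written against the public TF 2.x API (tf.function, tf.transpose, tf.constant, and tf.ones are the public aliases of the modules used in the deleted test; `experimental_compile=True` requests XLA compilation, as in the test):

import tensorflow as tf

@tf.function(experimental_compile=True)
def f():
  # The permutation is produced by running a function, not written as a
  # literal at the call site.
  return tf.constant([0, 2, 1], dtype=tf.int32)

@tf.function(experimental_compile=True)
def g(a, b):
  # XLA can lower tf.transpose only when the permutation `b` is a
  # compile-time constant.
  return tf.transpose(a, b)

@tf.function
def z():
  # g is emitted as a (Stateful)PartitionedCall node. Const analysis must
  # peek into the called function to learn that g's second argument has to
  # be constant, so that f() can be constant-folded at compile time.
  return g(tf.ones([3, 4, 3], dtype=tf.float32), f())

z()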