diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 55341c0a01f..37110442b26 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -350,6 +350,7 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d6083621f4..1cf3e10b774 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/types/variant.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -571,6 +572,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr<Graph> graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + bool is_inside_mustcompile = false; + TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr, + &is_inside_mustcompile); + // Performs a first function inlining pass before shape inference, since // otherwise shape inference can't see inside functions and a comprehensive // shape_map, including function ops, is needed to constant-propagate Shape @@ -622,6 +627,8 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) { graph_optimizer_options.inline_multi_device_functions = true; graph_optimizer_options.inline_impl_selection_group_functions = true; graph_optimizer_options.inline_with_single_device_body_placer = true; + graph_optimizer_options.ignore_noinline = is_inside_mustcompile; + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), /*device=*/nullptr, &graph, graph_optimizer_options); diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 746930750ad..ae1a2daa788 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -42,7 +42,7 @@ void GraphOptimizer::Optimize( const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn, bool inline_multi_device_functions, bool inline_impl_selection_group_functions, - bool inline_with_single_device_body_placer) { + bool inline_with_single_device_body_placer, bool ignore_noinline) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -116,6 +116,11 @@ void GraphOptimizer::Optimize( .inline_impl_selection_group_functions = true; } + if (ignore_noinline) { + expand_inline_opts.multi_device_options.ignore_noinline = true; + expand_inline_opts.native_options.ignore_noinline = true; + } + bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts); if (was_mutated) { DumpGraph("ExpandInlineFunctions", g); @@ -138,11 +143,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env, const Device* device, std::unique_ptr<Graph>* graph, const Options& options) { - Optimize(runtime, env, device, graph, options.shape_map, - options.cse_consider_fn, options.cf_consider_fn, - options.inline_multi_device_functions, - options.inline_impl_selection_group_functions, - options.inline_with_single_device_body_placer); + Optimize( + runtime, env, device, graph, options.shape_map, options.cse_consider_fn, + options.cf_consider_fn, options.inline_multi_device_functions, + options.inline_impl_selection_group_functions, + options.inline_with_single_device_body_placer, options.ignore_noinline); } void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g, diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 099ea8efa12..53bf532bd9c 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -58,6 +58,9 @@ class GraphOptimizer { // If true all functions will be inlined with a single device function // body placer strategy. bool inline_with_single_device_body_placer = false; + + // If true, the _noinline attribute on functions and callers is ignored. + bool ignore_noinline = false; }; explicit GraphOptimizer(const OptimizerOptions& opts); @@ -81,7 +84,8 @@ class GraphOptimizer { const NodePredicate& cf_consider_fn = nullptr, bool inline_multi_device_functions = false, bool inline_impl_selection_group_functions = false, - bool inline_with_single_device_body_placer = false); + bool inline_with_single_device_body_placer = false, + bool ignore_noinline = false); const OptimizerOptions& options() { return opts_; } diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 5fdf0487333..b63a3b434d4 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -355,6 +355,27 @@ class DefFunctionTest(test.TestCase): self.assertAllClose([5.0, 5.0, 5.0], g()) self.assertAllClose(compiled_g(), g()) + def testTensorListConcatGradNestedCompile(self): + + @def_function.function(experimental_compile=True) + def f(x): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, size=2, element_shape=[3]) + ta = ta.write(0, 2 * x) + ta = ta.write(1, 3 * x) + return ta.concat() + + @def_function.function(experimental_compile=True) + def g(): + x = constant_op.constant([3.14, 2.68, 7.69]) + with backprop.GradientTape() as tape: + tape.watch(x) + y = f(x) + out = tape.gradient(y, x) + return out + + self.assertAllClose([5.0, 5.0, 5.0], g()) + def testCumsum(self): @def_function.function(experimental_compile=True)