diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 55341c0a01f..37110442b26 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -350,6 +350,7 @@ cc_library(
         ":sharding_util",
         ":side_effect_util",
         ":tf2xla_util",
+        "//tensorflow/compiler/jit:common",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
         "//tensorflow/compiler/jit:xla_cluster_util",
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 3d6083621f4..1cf3e10b774 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "absl/types/variant.h"
+#include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/jit/shape_inference.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -571,6 +572,10 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   std::unique_ptr<Graph> graph(new Graph(options_.flib_def));
   CopyGraph(*fbody->graph, graph.get());
 
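+  // Check whether this function body carries the kXlaMustCompileAttr
+  // ("_XlaMustCompile") annotation; if it does, the _noinline hints are
+  // ignored when inlining functions below (see graph_optimizer_options).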
+  bool is_inside_mustcompile = false;
+  TryGetNodeAttr(AttrSlice(&fbody->fdef.attr()), kXlaMustCompileAttr,
+                 &is_inside_mustcompile);
+
   // Performs a first function inlining pass before shape inference, since
   // otherwise shape inference can't see inside functions and a comprehensive
   // shape_map, including function ops, is needed to constant-propagate Shape
@@ -622,6 +627,8 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   graph_optimizer_options.inline_multi_device_functions = true;
   graph_optimizer_options.inline_impl_selection_group_functions = true;
   graph_optimizer_options.inline_with_single_device_body_placer = true;
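+  // When compiling a must-compile function, ignore _noinline so that nested
+  // functions are inlined into the graph being compiled.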
+  graph_optimizer_options.ignore_noinline = is_inside_mustcompile;
+
   optimizer.Optimize(flib_runtime_, flib_runtime_->env(),
                      /*device=*/nullptr, &graph, graph_optimizer_options);
 
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 746930750ad..ae1a2daa788 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -42,7 +42,7 @@ void GraphOptimizer::Optimize(
     const NodePredicate& cse_consider_fn, const NodePredicate& cf_consider_fn,
     bool inline_multi_device_functions,
     bool inline_impl_selection_group_functions,
-    bool inline_with_single_device_body_placer) {
+    bool inline_with_single_device_body_placer, bool ignore_noinline) {
   Graph* g = graph->get();
   DumpGraph("Initial", g);
 
@@ -116,6 +116,11 @@ void GraphOptimizer::Optimize(
             .inline_impl_selection_group_functions = true;
       }
 
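+      // Honor the caller's request to ignore the _noinline attribute for both
+      // the multi-device and native function inlining paths.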
+      if (ignore_noinline) {
+        expand_inline_opts.multi_device_options.ignore_noinline = true;
+        expand_inline_opts.native_options.ignore_noinline = true;
+      }
+
       bool was_mutated = ExpandInlineFunctions(runtime, g, expand_inline_opts);
       if (was_mutated) {
         DumpGraph("ExpandInlineFunctions", g);
@@ -138,11 +143,11 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
                               const Device* device,
                               std::unique_ptr<Graph>* graph,
                               const Options& options) {
-  Optimize(runtime, env, device, graph, options.shape_map,
-           options.cse_consider_fn, options.cf_consider_fn,
-           options.inline_multi_device_functions,
-           options.inline_impl_selection_group_functions,
-           options.inline_with_single_device_body_placer);
+  Optimize(
+      runtime, env, device, graph, options.shape_map, options.cse_consider_fn,
+      options.cf_consider_fn, options.inline_multi_device_functions,
+      options.inline_impl_selection_group_functions,
+      options.inline_with_single_device_body_placer, options.ignore_noinline);
 }
 
 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 099ea8efa12..53bf532bd9c 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -58,6 +58,9 @@ class GraphOptimizer {
     // If true all functions will be inlined with a single device function
     // body placer strategy.
     bool inline_with_single_device_body_placer = false;
+
+    // If true, the '_noinline' attribute on functions and their call nodes is
+    // ignored, so such functions may still be inlined.
   };
 
   explicit GraphOptimizer(const OptimizerOptions& opts);
@@ -81,7 +84,8 @@ class GraphOptimizer {
       const NodePredicate& cf_consider_fn = nullptr,
       bool inline_multi_device_functions = false,
       bool inline_impl_selection_group_functions = false,
-      bool inline_with_single_device_body_placer = false);
+      bool inline_with_single_device_body_placer = false,
+      bool ignore_noinline = false);
 
   const OptimizerOptions& options() { return opts_; }
 
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 5fdf0487333..b63a3b434d4 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -355,6 +355,27 @@ class DefFunctionTest(test.TestCase):
     self.assertAllClose([5.0, 5.0, 5.0], g())
     self.assertAllClose(compiled_g(), g())
 
+  def testTensorListConcatGradNestedCompile(self):
+
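+    # Takes the gradient of one compiled function (f) from inside another
+    # compiled function (g).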
+    @def_function.function(experimental_compile=True)
+    def f(x):
+      ta = tensor_array_ops.TensorArray(
+          dtype=dtypes.float32, size=2, element_shape=[3])
+      ta = ta.write(0, 2 * x)
+      ta = ta.write(1, 3 * x)
+      return ta.concat()
+
+    @def_function.function(experimental_compile=True)
+    def g():
+      x = constant_op.constant([3.14, 2.68, 7.69])
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = f(x)
+        out = tape.gradient(y, x)
+      return out
+
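+    # tape.gradient sums over y = concat([2*x, 3*x]), so the gradient is
+    # 2 + 3 = 5 for every element of x.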
+    self.assertAllClose([5.0, 5.0, 5.0], g())
+
   def testCumsum(self):
 
     @def_function.function(experimental_compile=True)