diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 9c378bd3b2b..4a76bd79513 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -721,6 +721,8 @@ tf_xla_py_test(
         "//tensorflow/compiler/tests:xla_test",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:framework_ops",
     ],
 )
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 261bf3f5862..e7b4a6f84b2 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -555,7 +555,18 @@ class Function(object):
       return self._python_function(*args, **kwds)
 
     tracing_count = self._get_tracing_count()
-    result = self._call(*args, **kwds)
+    if self._experimental_compile:
+      # V2 control flow relies on XLAControlFlowContext to generate a
+      # XLA-compatible function graph.
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      try:
+        xla_context.Enter()
+        result = self._call(*args, **kwds)
+      finally:
+        xla_context.Exit()
+    else:
+      result = self._call(*args, **kwds)
+
     if tracing_count == self._get_tracing_count():
       self._call_counter.called_without_tracing()
       return result
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index c477f5e0532..5338725f88d 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -23,6 +23,8 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.platform import test
 
@@ -79,6 +81,34 @@ class DefFunctionTest(test.TestCase):
       # XLA support is not yet enabled for TF ROCm
       run_and_check(xla_func)
 
+  def testControlFlow(self):
+
+    @def_function.function(experimental_compile=True)
+    def f(x):
+      assert control_flow_util.GraphOrParentsInXlaContext(
+          ops.get_default_graph())
+      x = ops.convert_to_tensor(x)
+
+      def body(i, a):
+        return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2),
+                                            lambda: a + 3)
+
+      return control_flow_ops.while_loop(
+          lambda i, *_: i < 10,
+          body, (constant_op.constant(0), constant_op.constant(3.)),
+          maximum_iterations=10)[1]
+
+    @def_function.function(experimental_compile=True)
+    def g(x):
+      x = ops.convert_to_tensor(x)
+      with backprop.GradientTape() as tape:
+        tape.watch(x)
+        y = f(x)
+      return y, tape.gradient(y, x)
+
+    self.assertAllClose(40.0, f(2.0))
+    self.assertAllClose([40.0, 28.0], g(2.0))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()