diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
index 5c115f7ae31..a8a65dde131 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test
@@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
   def testMapAndBatchControlFlow(self, numa_aware):
 
     def map_fn(x):
-      previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
-      control_flow_ops.ENABLE_COND_V2 = True
+      previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
+      control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
       return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
-      control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
+      control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
       return return_value
 
     dataset = dataset_ops.Dataset.range(100).apply(
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index df3cebd2e0c..0e48d3c8758 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import versions
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
@@ -409,42 +408,12 @@ def enable_control_flow_v2(fn):
   """
 
   def wrapper(*args, **kwargs):
-    enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
-    enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
-    enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
-    control_flow_ops.ENABLE_COND_V2 = True
-    control_flow_ops.ENABLE_WHILE_V2 = True
-    tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
+    enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
     try:
       fn(*args, **kwargs)
     finally:
-      control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
-      control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
-      tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
-
-  return wrapper
-
-
-def enable_tensor_array_v2(fn):
-  """Decorator for enabling _GraphTensorArrayV2 on a test.
-
-  Note this enables _GraphTensorArrayV2 after running the test class's
-  setup/teardown methods.
-
-  Args:
-    fn: the function to be wrapped
-
-  Returns:
-    The wrapped function
-  """
-
-  def wrapper(*args, **kwargs):
-    enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
-    tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
-    try:
-      fn(*args, **kwargs)
-    finally:
-      tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
+      control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old
 
   return wrapper
 
@@ -493,7 +462,7 @@ def with_control_flow_v2(cls):
   Returns:
     cls with new test methods added
   """
-  if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
+  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
     return cls
 
   for name, value in cls.__dict__.copy().items():
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 0fd293ebba3..21ded25a116 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_array_ops
@@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase):
       v1_msg = "The two structures don't have the same nested structure"
       v2_msg = "Outputs of true_fn and false_fn must have the same structure"
       with self.assertRaisesRegexp(
-          ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg):
+          ValueError,
+          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
         r = control_flow_ops.cond(pred, fn1, fn2)
         self.evaluate(r)
 
@@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase):
       self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
 
       # v1 control flow gets None second derivative for some reason.
-      if not control_flow_ops.ENABLE_COND_V2:
+      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
         self.assertIsNone(grad_grad)
         return
 
@@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase):
 
     # In defuns, all prints should execute in program order.
     # This doesn't work with legacy control flow.
-    if control_flow_ops.ENABLE_COND_V2:
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
 
       @eager_function.defun
       def cond():
@@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase):
 
     # In defuns, all prints should execute in program order.
     # This doesn't work with legacy control flow.
-    if control_flow_ops.ENABLE_WHILE_V2:
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
 
       @eager_function.defun
       def while_loop():
@@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase):
     gs = gradients_impl.gradients(loop_no_xla, v)
     self.evaluate(gs)  # This should execute without error.
 
-    if control_flow_ops.ENABLE_WHILE_V2:
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
       with self.assertRaisesRegexp(
@@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase):
           lambda i, x: (i + 1, v * x), (0, 1.0),
           maximum_iterations=max_iter_holder[0])
 
-    if control_flow_ops.ENABLE_WHILE_V2:
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
       xla_context = control_flow_ops.XLAControlFlowContext()
       xla_context.Enter()
       with self.assertRaisesRegexp(
@@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase):
       self.assertEqual(sess.run(grad, {pred: True}), 8.0)
       self.assertEqual(sess.run(grad, {pred: False}), 0.0)
 
-      if not control_flow_ops.ENABLE_WHILE_V2:
+      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
         return
 
       self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
@@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase):
     #   outer_loop(x) = g(g(x)) = 4x + 81
     #   outer_loop'(x) = 4
     # Note that v1 control flow gets 4.0 as well if the cond is removed.
-    if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
+    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
       self.assertEqual(grad, 4.0)
 
   def testWhile_NestedInput(self):
@@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase):
 
     result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
     grad_theta = gradients_impl.gradients(result, theta)
-    if not control_flow_ops.ENABLE_WHILE_V2:
+    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
       with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
         gradients_impl.gradients(grad_theta, theta)
     grad_theta_stopped = array_ops.stop_gradient(grad_theta)
@@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase):
       self.assertEqual(r[1].eval(), 65536.0)
       self.assertEqual(grad.eval(), 524288.0)
       # while_v2 does not have stacks.
-      if not control_flow_ops.ENABLE_WHILE_V2:
+      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
         self.assertEqual(
             len([op for op in x.graph.get_operations() if op.type == "StackV2"
                 ]), 1)
diff --git a/tensorflow/python/kernel_tests/control_flow_util_v2_test.py b/tensorflow/python/kernel_tests/control_flow_util_v2_test.py
index d0374a77005..08d3214e288 100644
--- a/tensorflow/python/kernel_tests/control_flow_util_v2_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_util_v2_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import control_flow_util_v2
 from tensorflow.python.platform import test
 
@@ -30,14 +31,11 @@ from tensorflow.python.platform import test
 class ControlFlowUtilV2Test(test.TestCase):
 
   def setUp(self):
-    self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
-    self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
-    control_flow_ops.ENABLE_COND_V2 = True
-    control_flow_ops.ENABLE_WHILE_V2 = True
+    self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
 
   def tearDown(self):
-    control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old
-    control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old
 
   def _create_control_flow(self, expect_in_defun):
     """Helper method for testInDefun."""
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index 88625841bcc..6d8e3e83566 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import gradients_impl
@@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase):
 
   @test_util.run_deprecated_v1
   def testSkipEagerTensorArrayGradGrad(self):
-    if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
+    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
       self.skipTest("Legacy TensorArray does not support double derivatives.")
     with self.test_session(use_gpu=True) as session:
       x = constant_op.constant(4.0)
@@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase):
     with self.session(use_gpu=True):
       ta = _make_ta(3, "foo", dtype=dtypes.float32)
       # Test writing the wrong datatype
-      if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
           not context.executing_eagerly()):
         error_msg = ("Invalid data types; op elements string but list elements "
                      "float")
@@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase):
       with self.assertRaisesOpError(error_msg):
         self.evaluate(ta.write(0, "wrong_type_scalar").flow)
 
-      if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
           not context.executing_eagerly()):
         error_msg = "Trying to modify element -1 in a list with 3 elements."
       else:
@@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase):
       with self.assertRaisesOpError(error_msg):
         self.evaluate(ta.write(-1, 3.0).flow)
 
-      if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
           not context.executing_eagerly()):
         error_msg = "Trying to modify element 3 in a list with 3 elements"
       else:
@@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase):
 
       # Test reading wrong datatype (only possible when constructing graphs).
       if (not context.executing_eagerly() and
-          not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2):
+          not control_flow_util.ENABLE_CONTROL_FLOW_V2):
         r0_bad = gen_data_flow_ops.tensor_array_read_v3(
             handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
         with self.assertRaisesOpError(
             "TensorArray dtype is float but Op requested dtype double."):
           self.evaluate(r0_bad)
 
-      if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
           not context.executing_eagerly()):
         error_msg = "Trying to access element -1 in a list with 3 elements."
       else:
@@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase):
       with self.assertRaisesOpError(error_msg):
         self.evaluate(ta.read(-1))
 
-      if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+      if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
           not context.executing_eagerly()):
         error_msg = "Trying to access element 3 in a list with 3 elements."
       else:
@@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase):
           ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
 
       error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
-                   if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
+                   if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
                    not in_eager_mode else
                    r"Expected sum of lengths to be equal to values.shape\[0\], "
                    r"but sum of lengths is 1 and value's shape is: \[3\]")
@@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase):
         self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
 
       ta = _make_ta(1, "baz")
-      if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode:
+      if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
         with self.assertRaisesRegexp(
             ValueError, "Shape must be at least rank 1 but is rank 0"):
           self.evaluate(ta.split(1.0, [1]).flow)
@@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase):
         ):
           self.evaluate(ta.split(1.0, [1]).flow)
 
-      if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode:
+      if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
         ta = _make_ta(2, "buz")
         with self.assertRaisesOpError(
             r"TensorArray's size is not equal to the size of lengths "
@@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase):
     # self._testWhileLoopWritePackGradients(
     #     dynamic_size=False, dtype=tf.int64)
 
-  @test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA")
-  @test_util.enable_tensor_array_v2
-  def testWhileLoopV1WithTensorArrayV2(self):
-    size = 3
-    ta = tensor_array_ops.TensorArray(
-        dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar())
-
-    def Body(counter, ta):
-      return counter + 1, ta.write(counter, counter)
-
-    _, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta])
-
-    for i in range(size):
-      self.assertEqual(self.evaluate(ta.read(i)), i)
-
   @test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
   @test_util.run_v1_only("b/117943489")
   def testSkipEagerWhileLoopDynamicWritePackGradients(self):
@@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase):
         self.assertEqual((2, 2), w0.read(1).get_shape())
       else:
         self.assertEqual(r0.get_shape().ndims, None)
-        if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
+        if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
           self.assertEqual(
               tensor_shape.TensorShape(
                   ta1.handle.op.get_attr("element_shape")).ndims, None)
@@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase):
           "TensorArray has size zero, but element shape <unknown> is not "
           "fully defined. Currently only static shapes are supported when "
           "packing zero-size TensorArrays.")
-      with self.assertRaisesOpError(v2_msg if tensor_array_ops
-                                    .ENABLE_TENSOR_ARRAY_V2 else v1_msg):
+      with self.assertRaisesOpError(
+          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
         ta.stack().eval()
 
   @test_util.run_v1_only("b/120545219")
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index b7e50c1dae5..99216d7fb15 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -24,13 +24,11 @@ from __future__ import print_function
 import abc
 import collections
 import functools
-import os
 
 import six
 
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.protobuf import control_flow_pb2
-from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(),
 while_v2 = LazyLoader("while_v2", globals(),
                       "tensorflow.python.ops.while_v2")
 
-ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0"
-ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
-
 # We override the 'tuple' for a control flow op, so we keep python's
 # existing 'tuple' for later use in this module.
 _basetuple = tuple
@@ -2052,7 +2047,7 @@ def cond(pred,
   ```
 
   """
-  if ENABLE_COND_V2 and not context.executing_eagerly():
+  if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
     return cond_v2.cond_v2(pred, true_fn, false_fn, name)
 
   # We needed to make true_fn/false_fn keyword arguments for
@@ -3487,7 +3482,7 @@ def while_loop(cond,
   ```
 
   """
-  if ENABLE_WHILE_V2 and not context.executing_eagerly():
+  if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
     return while_v2.while_loop(
         cond,
         body,
diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py
index 9ba5ff2c0f8..9dd1e6673b8 100644
--- a/tensorflow/python/ops/control_flow_ops_benchmark.py
+++ b/tensorflow/python/ops/control_flow_ops_benchmark.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark):
               iters=self.NUM_ITERS)
 
   def benchmark_cond_v1_defun(self):
-    old_val = control_flow_ops.ENABLE_COND_V2
-    control_flow_ops.ENABLE_COND_V2 = False
+    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
     self._benchmark_defun()
-    control_flow_ops.ENABLE_COND_V2 = old_val
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
 
   def benchmark_cond_v2_defun(self):
-    old_val = control_flow_ops.ENABLE_COND_V2
-    control_flow_ops.ENABLE_COND_V2 = True
+    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
     self._benchmark_defun()
-    control_flow_ops.ENABLE_COND_V2 = old_val
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
 
   def benchmark_cond_v1_graph(self):
-    old_val = control_flow_ops.ENABLE_COND_V2
-    control_flow_ops.ENABLE_COND_V2 = False
+    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
     self._benchmark_graph()
-    control_flow_ops.ENABLE_COND_V2 = old_val
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
 
   def benchmark_cond_v2_graph(self):
-    old_val = control_flow_ops.ENABLE_COND_V2
-    control_flow_ops.ENABLE_COND_V2 = True
+    old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
     self._benchmark_graph()
-    control_flow_ops.ENABLE_COND_V2 = old_val
+    control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
 
 if __name__ == "__main__":
   ops.enable_eager_execution()
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index cb628f4aa64..1747f06109d 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -23,10 +23,18 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import traceback
 
+from tensorflow.python import tf2
 from tensorflow.python.platform import tf_logging as logging
 
+ENABLE_CONTROL_FLOW_V2 = (tf2.enabled() or
+                          os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
+                          os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
+                          os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
+                          os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
+
 
 def IsInXLAContext(op):
   try:
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index d1516949517..85333ee6b56 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -20,10 +20,8 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
-import os
 import weakref
 
-from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -32,6 +30,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gen_control_flow_ops
 from tensorflow.python.ops import gen_data_flow_ops
 from tensorflow.python.ops import list_ops
@@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use
 from tensorflow.python.util.tf_export import tf_export
 
 
-ENABLE_TENSOR_ARRAY_V2 = (
-    tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None)
-
-
 # _GraphTensorArray accesses many of the hidden generated ops, but is in
 # fact built to wrap these methods.
 # pylint: disable=protected-access
@@ -1013,7 +1008,7 @@ class TensorArray(object):
     if context.executing_eagerly():
       implementation = _EagerTensorArray
     else:
-      if ENABLE_TENSOR_ARRAY_V2:
+      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
         implementation = _GraphTensorArrayV2
       else:
         implementation = _GraphTensorArray
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index d00c158d156..f7566bac9bd 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -52,13 +52,6 @@ from tensorflow.python.util import nest
 # to them and then pass those in as data inputs. This should probably be
 # handled in the CapturingGraph itself.
 
-# Op types that output a resource tensor representing a TensorArray handle.
-TENSOR_ARRAY_HANDLE_OPS = (
-    "TensorArrayV3",
-    "TensorArrayGradV3",
-    "TensorArrayGradWithShape",
-)
-
 
 def while_loop(cond,
                body,
@@ -257,24 +250,19 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
       "_maximum_iterations") if _is_in_xla_context() else None
   assert not _is_in_xla_context() or maximum_iterations is not None
 
-  # Set the incoming gradient of TensorArray handles to None. The gradient
-  # implementation currently assumes all resource tensors correspond to float32
-  # ResourceVariables, which can lead to runtime shape errors when used with a
-  # TensorArray. This is a workaround until TensorArrays are reimplemented with
-  # TensorLists instead of resources.
-  # Also set the incoming gradient of non-trainable inputs to None. It is
-  # possible that we receive non-None gradients for non-trainable types in
-  # nested while loops because we accumulate outputs of the inner while as
-  # variant tensors which are trainable and hence receive zeros_like tensors in
-  # the gradient pass. The non-trainable tensors then receive the popped zeros
-  # tensor from this zeros variant. The gradient for the loop vars corresponding
-  # to these tensors is None or zeros (this happens only if the loop var is
-  # accumulated as well) in _grad_fn so we reset these.
+  # Set the incoming gradient of non-trainable inputs to None. It is possible
+  # that we receive non-None gradients for non-trainable types in nested while
+  # loops because we accumulate outputs of the inner while as variant tensors
+  # which are trainable and hence receive zeros_like tensors in the gradient
+  # pass. The non-trainable tensors then receive the popped zeros tensor from
+  # this zeros variant. The gradient for the loop vars corresponding to these
+  # tensors is None or zeros (this happens only if the loop var is accumulated
+  # as well) in _grad_fn so we reset these.
   # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
   # output grads in _grad_fn.
   grads = [
-      None if _is_tensor_array_handle(output) or not _is_trainable(output)
-      else grad for grad, output in zip(grads, body_graph.outputs)
+      None if not _is_trainable(output) else grad
+      for grad, output in zip(grads, body_graph.outputs)
   ]
 
   # Ensure that all non-resource trainable outputs have incoming gradients.
@@ -339,8 +327,7 @@ def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
   # See comment in while_loop.
   outputs = [array_ops.identity(t) for t in outputs]
 
-  # Set None as the output gradient for tensors with None input gradient
-  # e.g. TensorArray handles.
+  # Set None as the output gradient for tensors with None input gradient.
   # outputs[0] is the loop counter.
   # outputs[1] is the total number of loop iterations.
   index = 2
@@ -853,28 +840,6 @@ def _graph_name(graph):
   return "Base"
 
 
-def _is_tensor_array_handle(tensor):
-  """Returns whether tensor is a TensorArray handle."""
-  if tensor.dtype != dtypes.resource:
-    return False
-
-  if tensor.op.type == "While":
-    # We assume that any resource outputs of a While op correspond to a captured
-    # resource input (as opposed to a loop variable specified by the user).
-    # NOTE(skyewm): we could actually check this, but I can't think of when you
-    # would have a resource loop variable.
-    tensor = tensor.op.inputs[tensor.value_index]
-
-  # TODO(b/118452219): add test coverage for this.
-  tensor = func_graph_module.maybe_captured(tensor)
-
-  if isinstance(tensor, ops.EagerTensor):
-    # Eager execution doesn't quite support legacy tensorarray
-    return False
-
-  return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
-
-
 def _pack_sequence_as(structure_with_tas, loop_vars):
   """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""