From 19518ef98e7e655a867442270228d1d4ab9e4b29 Mon Sep 17 00:00:00 2001
From: Saurabh Saxena <srbs@google.com>
Date: Fri, 2 Aug 2019 19:58:31 -0700
Subject: [PATCH] Enable v2 control flow inside graph_mode in TF 2.0. Current
 behavior is to use v2 control flow only inside tf.function.

PiperOrigin-RevId: 261435951
---
 tensorflow/python/BUILD                       | 25 ++++++++++++
 tensorflow/python/compat/v2_compat.py         |  3 +-
 .../data/experimental/kernel_tests/BUILD      |  1 +
 .../experimental/kernel_tests/scan_test.py    |  4 ++
 tensorflow/python/keras/BUILD                 |  2 +-
 .../python/ops/control_flow_ops_test.py       |  3 ++
 tensorflow/python/ops/control_flow_util.py    |  5 ++-
 .../ops/control_flow_v2_disable_test.py       | 39 +++++++++++++++++++
 .../python/ops/control_flow_v2_enable_test.py | 38 ++++++++++++++++++
 9 files changed, 116 insertions(+), 4 deletions(-)
 create mode 100644 tensorflow/python/ops/control_flow_v2_disable_test.py
 create mode 100644 tensorflow/python/ops/control_flow_v2_enable_test.py

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 55bb8a2c116..68dba890e70 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2609,6 +2609,30 @@ tf_py_test(
     ],
 )
 
+tf_py_test(
+    name = "control_flow_v2_enable_test",
+    size = "small",
+    srcs = ["ops/control_flow_v2_enable_test.py"],
+    additional_deps = [
+        ":tf2",
+        ":control_flow_util",
+        ":client_testlib",
+        ":platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "control_flow_v2_disable_test",
+    size = "small",
+    srcs = ["ops/control_flow_v2_disable_test.py"],
+    additional_deps = [
+        ":tf2",
+        ":control_flow_util",
+        ":client_testlib",
+        ":platform_test",
+    ],
+)
+
 py_library(
     name = "cond_v2",
     srcs = [
@@ -3926,6 +3950,7 @@ cuda_py_test(
         ":array_ops",
         ":cond_v2",
         ":control_flow_ops",
+        ":control_flow_v2_toggles",
         ":embedding_ops",
         ":framework_for_generated_wrappers",
         ":framework_test_lib",
diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py
index 85381089b7c..e2c4c6a4316 100644
--- a/tensorflow/python/compat/v2_compat.py
+++ b/tensorflow/python/compat/v2_compat.py
@@ -44,8 +44,7 @@ def enable_v2_behavior():
   tensor_shape.enable_v2_tensorshape()  # Also switched by tf2
   variable_scope.enable_resource_variables()
   # Enables TensorArrayV2 and control flow V2.
-  # TODO(b/134181885): Re-enable this.
-  # control_flow_v2_toggles.enable_control_flow_v2()
+  control_flow_v2_toggles.enable_control_flow_v2()
 
 
 @tf_export(v1=["disable_v2_behavior"])
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 9260af4fa01..d609100f2a1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -574,6 +574,7 @@ py_test(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_v2_toggles",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
index 0932a25488a..8f059c41532 100644
--- a/tensorflow/python/data/experimental/kernel_tests/scan_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import tensor_array_ops
@@ -156,6 +157,9 @@ class ScanTest(test_base.DatasetTestBase):
 
   def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
 
+    if control_flow_v2_toggles.control_flow_v2_enabled():
+      self.skipTest("v1 only test")
+
     empty_ta = tensor_array_ops.TensorArray(
         size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
 
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 55ff8ae8964..7f642245296 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1126,7 +1126,7 @@ tf_py_test(
 
 tf_py_test(
     name = "wrappers_test",
-    size = "medium",
+    size = "large",
     srcs = ["layers/wrappers_test.py"],
     additional_deps = [
         ":keras",
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 005d17511bb..a32f33f2fac 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -39,6 +39,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_v2_toggles
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
@@ -1082,6 +1083,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
   @test_util.disable_xla("Wants RunMetadata")
   def testParallelExecution(self):
     """Verify disjoint branches across while iterations are run in parallel."""
+    if control_flow_v2_toggles.control_flow_v2_enabled():
+      self.skipTest("b/138870290")
     if test.is_built_with_rocm():
       self.skipTest(
           "Disable subtest on ROCm due to missing Cholesky op support")
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index a2e8a65a309..0f984189aef 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -26,9 +26,12 @@ from __future__ import print_function
 import os
 import traceback
 
+from tensorflow.python import tf2
 from tensorflow.python.platform import tf_logging as logging
 
-ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
+ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and
+                           os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
+                          os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
                           os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
diff --git a/tensorflow/python/ops/control_flow_v2_disable_test.py b/tensorflow/python/ops/control_flow_v2_disable_test.py
new file mode 100644
index 00000000000..f6e3888a84c
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_disable_test.py
@@ -0,0 +1,39 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that TF2_BEHAVIOR=1 and TF_ENABLE_CONTROL_FLOW_V2=0 disables cfv2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+os.environ["TF2_BEHAVIOR"] = "1"
+os.environ["TF_ENABLE_CONTROL_FLOW_V2"] = "0"
+
+from tensorflow.python import tf2  # pylint: disable=g-import-not-at-top
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ControlFlowV2DisableTest(test.TestCase):
+
+  def testIsDisabled(self):
+    self.assertTrue(tf2.enabled())
+    self.assertFalse(control_flow_util.ENABLE_CONTROL_FLOW_V2)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/ops/control_flow_v2_enable_test.py b/tensorflow/python/ops/control_flow_v2_enable_test.py
new file mode 100644
index 00000000000..f29d4dc4a21
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_v2_enable_test.py
@@ -0,0 +1,38 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that TF2_BEHAVIOR=1 enables cfv2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+os.environ["TF2_BEHAVIOR"] = "1"
+
+from tensorflow.python import tf2  # pylint: disable=g-import-not-at-top
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+class ControlFlowV2EnableTest(test.TestCase):
+
+  def testIsEnabled(self):
+    self.assertTrue(tf2.enabled())
+    self.assertTrue(control_flow_util.ENABLE_CONTROL_FLOW_V2)
+
+
+if __name__ == "__main__":
+  googletest.main()