diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 55bb8a2c116..68dba890e70 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2609,6 +2609,30 @@ tf_py_test( ], ) +tf_py_test( + name = "control_flow_v2_enable_test", + size = "small", + srcs = ["ops/control_flow_v2_enable_test.py"], + additional_deps = [ + ":tf2", + ":control_flow_util", + ":client_testlib", + ":platform_test", + ], +) + +tf_py_test( + name = "control_flow_v2_disable_test", + size = "small", + srcs = ["ops/control_flow_v2_disable_test.py"], + additional_deps = [ + ":tf2", + ":control_flow_util", + ":client_testlib", + ":platform_test", + ], +) + py_library( name = "cond_v2", srcs = [ @@ -3926,6 +3950,7 @@ cuda_py_test( ":array_ops", ":cond_v2", ":control_flow_ops", + ":control_flow_v2_toggles", ":embedding_ops", ":framework_for_generated_wrappers", ":framework_test_lib", diff --git a/tensorflow/python/compat/v2_compat.py b/tensorflow/python/compat/v2_compat.py index 85381089b7c..e2c4c6a4316 100644 --- a/tensorflow/python/compat/v2_compat.py +++ b/tensorflow/python/compat/v2_compat.py @@ -44,8 +44,7 @@ def enable_v2_behavior(): tensor_shape.enable_v2_tensorshape() # Also switched by tf2 variable_scope.enable_resource_variables() # Enables TensorArrayV2 and control flow V2. - # TODO(b/134181885): Re-enable this. - # control_flow_v2_toggles.enable_control_flow_v2() + control_flow_v2_toggles.enable_control_flow_v2() @tf_export(v1=["disable_v2_behavior"]) diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 9260af4fa01..d609100f2a1 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -574,6 +574,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_v2_toggles", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py index 0932a25488a..8f059c41532 100644 --- a/tensorflow/python/data/experimental/kernel_tests/scan_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import tensor_array_ops @@ -156,6 +157,9 @@ class ScanTest(test_base.DatasetTestBase): def testTensorArrayWithCondResetByExternalCaptureBreaks(self): + if control_flow_v2_toggles.control_flow_v2_enabled(): + self.skipTest("v1 only test") + empty_ta = tensor_array_ops.TensorArray( size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True) diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index 55ff8ae8964..7f642245296 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -1126,7 +1126,7 @@ tf_py_test( tf_py_test( name = "wrappers_test", - size = "medium", + size = "large", srcs = ["layers/wrappers_test.py"], additional_deps = [ ":keras", diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 005d17511bb..a32f33f2fac 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -39,6 +39,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -1082,6 +1083,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase): @test_util.disable_xla("Wants RunMetadata") def testParallelExecution(self): """Verify disjoint branches across while iterations are run in parallel.""" + if control_flow_v2_toggles.control_flow_v2_enabled(): + self.skipTest("b/138870290") if test.is_built_with_rocm(): self.skipTest( "Disable subtest on ROCm due to missing Cholesky op support") diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index a2e8a65a309..0f984189aef 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -26,9 +26,12 @@ from __future__ import print_function import os import traceback +from tensorflow.python import tf2 from tensorflow.python.platform import tf_logging as logging -ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or +ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and + os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or + os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or os.getenv("TF_ENABLE_COND_V2", "0") != "0" or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0") diff --git a/tensorflow/python/ops/control_flow_v2_disable_test.py b/tensorflow/python/ops/control_flow_v2_disable_test.py new file mode 100644 index 00000000000..f6e3888a84c --- /dev/null +++ b/tensorflow/python/ops/control_flow_v2_disable_test.py @@ -0,0 +1,39 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that TF2_BEHAVIOR=1 and TF_ENABLE_CONTROL_FLOW_V2=0 disables cfv2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +os.environ["TF2_BEHAVIOR"] = "1" +os.environ["TF_ENABLE_CONTROL_FLOW_V2"] = "0" + +from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top +from tensorflow.python.ops import control_flow_util +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class ControlFlowV2DisableTest(test.TestCase): + + def testIsDisabled(self): + self.assertTrue(tf2.enabled()) + self.assertFalse(control_flow_util.ENABLE_CONTROL_FLOW_V2) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/ops/control_flow_v2_enable_test.py b/tensorflow/python/ops/control_flow_v2_enable_test.py new file mode 100644 index 00000000000..f29d4dc4a21 --- /dev/null +++ b/tensorflow/python/ops/control_flow_v2_enable_test.py @@ -0,0 +1,38 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests that TF2_BEHAVIOR=1 enables cfv2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +os.environ["TF2_BEHAVIOR"] = "1" + +from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top +from tensorflow.python.ops import control_flow_util +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test + + +class ControlFlowV2EnableTest(test.TestCase): + + def testIsEnabled(self): + self.assertTrue(tf2.enabled()) + self.assertTrue(control_flow_util.ENABLE_CONTROL_FLOW_V2) + + +if __name__ == "__main__": + googletest.main()