Enable v2 control flow inside graph_mode in TF 2.0.

Current behavior is to use v2 control flow only inside tf.function.

PiperOrigin-RevId: 261435951
This commit is contained in:
Saurabh Saxena 2019-08-02 19:58:31 -07:00 committed by TensorFlower Gardener
parent 326a74c721
commit 19518ef98e
9 changed files with 116 additions and 4 deletions

View File

@ -2609,6 +2609,30 @@ tf_py_test(
],
)
tf_py_test(
name = "control_flow_v2_enable_test",
size = "small",
srcs = ["ops/control_flow_v2_enable_test.py"],
additional_deps = [
":tf2",
":control_flow_util",
":client_testlib",
":platform_test",
],
)
tf_py_test(
name = "control_flow_v2_disable_test",
size = "small",
srcs = ["ops/control_flow_v2_disable_test.py"],
additional_deps = [
":tf2",
":control_flow_util",
":client_testlib",
":platform_test",
],
)
py_library(
name = "cond_v2",
srcs = [
@ -3926,6 +3950,7 @@ cuda_py_test(
":array_ops",
":cond_v2",
":control_flow_ops",
":control_flow_v2_toggles",
":embedding_ops",
":framework_for_generated_wrappers",
":framework_test_lib",

View File

@ -44,8 +44,7 @@ def enable_v2_behavior():
tensor_shape.enable_v2_tensorshape() # Also switched by tf2
variable_scope.enable_resource_variables()
# Enables TensorArrayV2 and control flow V2.
# TODO(b/134181885): Re-enable this.
# control_flow_v2_toggles.enable_control_flow_v2()
control_flow_v2_toggles.enable_control_flow_v2()
@tf_export(v1=["disable_v2_behavior"])

View File

@ -574,6 +574,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_v2_toggles",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
@ -156,6 +157,9 @@ class ScanTest(test_base.DatasetTestBase):
def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
if control_flow_v2_toggles.control_flow_v2_enabled():
self.skipTest("v1 only test")
empty_ta = tensor_array_ops.TensorArray(
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

View File

@ -1126,7 +1126,7 @@ tf_py_test(
tf_py_test(
name = "wrappers_test",
size = "medium",
size = "large",
srcs = ["layers/wrappers_test.py"],
additional_deps = [
":keras",

View File

@ -39,6 +39,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
@ -1082,6 +1083,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@test_util.disable_xla("Wants RunMetadata")
def testParallelExecution(self):
"""Verify disjoint branches across while iterations are run in parallel."""
if control_flow_v2_toggles.control_flow_v2_enabled():
self.skipTest("b/138870290")
if test.is_built_with_rocm():
self.skipTest(
"Disable subtest on ROCm due to missing Cholesky op support")

View File

@ -26,9 +26,12 @@ from __future__ import print_function
import os
import traceback
from tensorflow.python import tf2
from tensorflow.python.platform import tf_logging as logging
ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and
os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")

View File

@ -0,0 +1,39 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that TF2_BEHAVIOR=1 and TF_ENABLE_CONTROL_FLOW_V2=0 disables cfv2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ["TF2_BEHAVIOR"] = "1"
os.environ["TF_ENABLE_CONTROL_FLOW_V2"] = "0"
from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
class ControlFlowV2DisableTest(test.TestCase):
def testIsDisabled(self):
self.assertTrue(tf2.enabled())
self.assertFalse(control_flow_util.ENABLE_CONTROL_FLOW_V2)
if __name__ == "__main__":
googletest.main()

View File

@ -0,0 +1,38 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests that TF2_BEHAVIOR=1 enables cfv2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
os.environ["TF2_BEHAVIOR"] = "1"
from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
class ControlFlowV2EnableTest(test.TestCase):
def testIsEnabled(self):
self.assertTrue(tf2.enabled())
self.assertTrue(control_flow_util.ENABLE_CONTROL_FLOW_V2)
if __name__ == "__main__":
googletest.main()