Enable v2 control flow inside graph_mode in TF 2.0.
Current behavior is to use v2 control flow only inside tf.function. PiperOrigin-RevId: 261435951
This commit is contained in:
parent
326a74c721
commit
19518ef98e
tensorflow/python
@ -2609,6 +2609,30 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "control_flow_v2_enable_test",
|
||||
size = "small",
|
||||
srcs = ["ops/control_flow_v2_enable_test.py"],
|
||||
additional_deps = [
|
||||
":tf2",
|
||||
":control_flow_util",
|
||||
":client_testlib",
|
||||
":platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "control_flow_v2_disable_test",
|
||||
size = "small",
|
||||
srcs = ["ops/control_flow_v2_disable_test.py"],
|
||||
additional_deps = [
|
||||
":tf2",
|
||||
":control_flow_util",
|
||||
":client_testlib",
|
||||
":platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "cond_v2",
|
||||
srcs = [
|
||||
@ -3926,6 +3950,7 @@ cuda_py_test(
|
||||
":array_ops",
|
||||
":cond_v2",
|
||||
":control_flow_ops",
|
||||
":control_flow_v2_toggles",
|
||||
":embedding_ops",
|
||||
":framework_for_generated_wrappers",
|
||||
":framework_test_lib",
|
||||
|
@ -44,8 +44,7 @@ def enable_v2_behavior():
|
||||
tensor_shape.enable_v2_tensorshape() # Also switched by tf2
|
||||
variable_scope.enable_resource_variables()
|
||||
# Enables TensorArrayV2 and control flow V2.
|
||||
# TODO(b/134181885): Re-enable this.
|
||||
# control_flow_v2_toggles.enable_control_flow_v2()
|
||||
control_flow_v2_toggles.enable_control_flow_v2()
|
||||
|
||||
|
||||
@tf_export(v1=["disable_v2_behavior"])
|
||||
|
@ -574,6 +574,7 @@ py_test(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_v2_toggles",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
@ -156,6 +157,9 @@ class ScanTest(test_base.DatasetTestBase):
|
||||
|
||||
def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
|
||||
|
||||
if control_flow_v2_toggles.control_flow_v2_enabled():
|
||||
self.skipTest("v1 only test")
|
||||
|
||||
empty_ta = tensor_array_ops.TensorArray(
|
||||
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
|
||||
|
||||
|
@ -1126,7 +1126,7 @@ tf_py_test(
|
||||
|
||||
tf_py_test(
|
||||
name = "wrappers_test",
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["layers/wrappers_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
|
@ -39,6 +39,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
from tensorflow.python.ops import init_ops
|
||||
@ -1082,6 +1083,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
@test_util.disable_xla("Wants RunMetadata")
|
||||
def testParallelExecution(self):
|
||||
"""Verify disjoint branches across while iterations are run in parallel."""
|
||||
if control_flow_v2_toggles.control_flow_v2_enabled():
|
||||
self.skipTest("b/138870290")
|
||||
if test.is_built_with_rocm():
|
||||
self.skipTest(
|
||||
"Disable subtest on ROCm due to missing Cholesky op support")
|
||||
|
@ -26,9 +26,12 @@ from __future__ import print_function
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
ENABLE_CONTROL_FLOW_V2 = (os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
|
||||
ENABLE_CONTROL_FLOW_V2 = ((tf2.enabled() and
|
||||
os.getenv("TF_ENABLE_CONTROL_FLOW_V2") != "0") or
|
||||
os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
|
||||
|
39
tensorflow/python/ops/control_flow_v2_disable_test.py
Normal file
39
tensorflow/python/ops/control_flow_v2_disable_test.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests that TF2_BEHAVIOR=1 and TF_ENABLE_CONTROL_FLOW_V2=0 disables cfv2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
os.environ["TF2_BEHAVIOR"] = "1"
|
||||
os.environ["TF_ENABLE_CONTROL_FLOW_V2"] = "0"
|
||||
|
||||
from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ControlFlowV2DisableTest(test.TestCase):
|
||||
|
||||
def testIsDisabled(self):
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertFalse(control_flow_util.ENABLE_CONTROL_FLOW_V2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
38
tensorflow/python/ops/control_flow_v2_enable_test.py
Normal file
38
tensorflow/python/ops/control_flow_v2_enable_test.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests that TF2_BEHAVIOR=1 enables cfv2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
os.environ["TF2_BEHAVIOR"] = "1"
|
||||
|
||||
from tensorflow.python import tf2 # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class ControlFlowV2EnableTest(test.TestCase):
|
||||
|
||||
def testIsEnabled(self):
|
||||
self.assertTrue(tf2.enabled())
|
||||
self.assertTrue(control_flow_util.ENABLE_CONTROL_FLOW_V2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
googletest.main()
|
Loading…
Reference in New Issue
Block a user