Introduce consolidated ENABLE_CONTROL_FLOW_V2 flag.
The new toggle replaces ENABLE_COND_V2, ENABLE_WHILE_V2, and ENABLE_TENSOR_ARRAY_V2. This means that these can't be toggled independently anymore, notably that v1 TensorArrays can only be run with v1 loops, and v2 TensorArrays with v2 loops. This also introduces a corresponding environment variable TF_ENABLE_CONTROL_FLOW_V2. I kept the old env vars as well in case people are using them. They all flip the new single toggle now. In addition, this change removes some while_v2 code for dealing with v1 TensorArrays, since this is no longer a supported configuration. PiperOrigin-RevId: 224862245
This commit is contained in:
parent
ee418c8ee2
commit
1d54cbf4a2
tensorflow/python
data/experimental/kernel_tests
framework
kernel_tests
ops
@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
def testMapAndBatchControlFlow(self, numa_aware):
|
||||
|
||||
def map_fn(x):
|
||||
previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
|
||||
control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
|
||||
return return_value
|
||||
|
||||
dataset = dataset_ops.Dataset.range(100).apply(
|
||||
|
@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -409,42 +408,12 @@ def enable_control_flow_v2(fn):
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
|
||||
enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
|
||||
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
control_flow_ops.ENABLE_WHILE_V2 = True
|
||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
|
||||
enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||
try:
|
||||
fn(*args, **kwargs)
|
||||
finally:
|
||||
control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
|
||||
control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
|
||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def enable_tensor_array_v2(fn):
|
||||
"""Decorator for enabling _GraphTensorArrayV2 on a test.
|
||||
|
||||
Note this enables _GraphTensorArrayV2 after running the test class's
|
||||
setup/teardown methods.
|
||||
|
||||
Args:
|
||||
fn: the function to be wrapped
|
||||
|
||||
Returns:
|
||||
The wrapped function
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
|
||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
|
||||
try:
|
||||
fn(*args, **kwargs)
|
||||
finally:
|
||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -493,7 +462,7 @@ def with_control_flow_v2(cls):
|
||||
Returns:
|
||||
cls with new test methods added
|
||||
"""
|
||||
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
return cls
|
||||
|
||||
for name, value in cls.__dict__.copy().items():
|
||||
|
@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase):
|
||||
v1_msg = "The two structures don't have the same nested structure"
|
||||
v2_msg = "Outputs of true_fn and false_fn must have the same structure"
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg):
|
||||
ValueError,
|
||||
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
||||
r = control_flow_ops.cond(pred, fn1, fn2)
|
||||
self.evaluate(r)
|
||||
|
||||
@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase):
|
||||
self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
|
||||
|
||||
# v1 control flow gets None second derivative for some reason.
|
||||
if not control_flow_ops.ENABLE_COND_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.assertIsNone(grad_grad)
|
||||
return
|
||||
|
||||
@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
# In defuns, all prints should execute in program order.
|
||||
# This doesn't work with legacy control flow.
|
||||
if control_flow_ops.ENABLE_COND_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
|
||||
@eager_function.defun
|
||||
def cond():
|
||||
@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
# In defuns, all prints should execute in program order.
|
||||
# This doesn't work with legacy control flow.
|
||||
if control_flow_ops.ENABLE_WHILE_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
|
||||
@eager_function.defun
|
||||
def while_loop():
|
||||
@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase):
|
||||
gs = gradients_impl.gradients(loop_no_xla, v)
|
||||
self.evaluate(gs) # This should execute without error.
|
||||
|
||||
if control_flow_ops.ENABLE_WHILE_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase):
|
||||
lambda i, x: (i + 1, v * x), (0, 1.0),
|
||||
maximum_iterations=max_iter_holder[0])
|
||||
|
||||
if control_flow_ops.ENABLE_WHILE_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
with self.assertRaisesRegexp(
|
||||
@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase):
|
||||
self.assertEqual(sess.run(grad, {pred: True}), 8.0)
|
||||
self.assertEqual(sess.run(grad, {pred: False}), 0.0)
|
||||
|
||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
return
|
||||
|
||||
self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
|
||||
@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase):
|
||||
# outer_loop(x) = g(g(x)) = 4x + 81
|
||||
# outer_loop'(x) = 4
|
||||
# Note that v1 control flow gets 4.0 as well if the cond is removed.
|
||||
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.assertEqual(grad, 4.0)
|
||||
|
||||
def testWhile_NestedInput(self):
|
||||
@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase):
|
||||
|
||||
result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
|
||||
grad_theta = gradients_impl.gradients(result, theta)
|
||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
|
||||
gradients_impl.gradients(grad_theta, theta)
|
||||
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
|
||||
@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase):
|
||||
self.assertEqual(r[1].eval(), 65536.0)
|
||||
self.assertEqual(grad.eval(), 524288.0)
|
||||
# while_v2 does not have stacks.
|
||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.assertEqual(
|
||||
len([op for op in x.graph.get_operations() if op.type == "StackV2"
|
||||
]), 1)
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -30,14 +31,11 @@ from tensorflow.python.platform import test
|
||||
class ControlFlowUtilV2Test(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
|
||||
self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
control_flow_ops.ENABLE_WHILE_V2 = True
|
||||
self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||
|
||||
def tearDown(self):
|
||||
control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old
|
||||
control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old
|
||||
|
||||
def _create_control_flow(self, expect_in_defun):
|
||||
"""Helper method for testInDefun."""
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import gen_data_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSkipEagerTensorArrayGradGrad(self):
|
||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.skipTest("Legacy TensorArray does not support double derivatives.")
|
||||
with self.test_session(use_gpu=True) as session:
|
||||
x = constant_op.constant(4.0)
|
||||
@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.session(use_gpu=True):
|
||||
ta = _make_ta(3, "foo", dtype=dtypes.float32)
|
||||
# Test writing the wrong datatype
|
||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not context.executing_eagerly()):
|
||||
error_msg = ("Invalid data types; op elements string but list elements "
|
||||
"float")
|
||||
@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(error_msg):
|
||||
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
||||
|
||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not context.executing_eagerly()):
|
||||
error_msg = "Trying to modify element -1 in a list with 3 elements."
|
||||
else:
|
||||
@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(error_msg):
|
||||
self.evaluate(ta.write(-1, 3.0).flow)
|
||||
|
||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not context.executing_eagerly()):
|
||||
error_msg = "Trying to modify element 3 in a list with 3 elements"
|
||||
else:
|
||||
@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase):
|
||||
|
||||
# Test reading wrong datatype (only possible when constructing graphs).
|
||||
if (not context.executing_eagerly() and
|
||||
not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2):
|
||||
not control_flow_util.ENABLE_CONTROL_FLOW_V2):
|
||||
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
|
||||
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
|
||||
with self.assertRaisesOpError(
|
||||
"TensorArray dtype is float but Op requested dtype double."):
|
||||
self.evaluate(r0_bad)
|
||||
|
||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not context.executing_eagerly()):
|
||||
error_msg = "Trying to access element -1 in a list with 3 elements."
|
||||
else:
|
||||
@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase):
|
||||
with self.assertRaisesOpError(error_msg):
|
||||
self.evaluate(ta.read(-1))
|
||||
|
||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not context.executing_eagerly()):
|
||||
error_msg = "Trying to access element 3 in a list with 3 elements."
|
||||
else:
|
||||
@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase):
|
||||
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
|
||||
|
||||
error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
|
||||
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||
not in_eager_mode else
|
||||
r"Expected sum of lengths to be equal to values.shape\[0\], "
|
||||
r"but sum of lengths is 1 and value's shape is: \[3\]")
|
||||
@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase):
|
||||
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
|
||||
|
||||
ta = _make_ta(1, "baz")
|
||||
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "Shape must be at least rank 1 but is rank 0"):
|
||||
self.evaluate(ta.split(1.0, [1]).flow)
|
||||
@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase):
|
||||
):
|
||||
self.evaluate(ta.split(1.0, [1]).flow)
|
||||
|
||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
|
||||
ta = _make_ta(2, "buz")
|
||||
with self.assertRaisesOpError(
|
||||
r"TensorArray's size is not equal to the size of lengths "
|
||||
@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase):
|
||||
# self._testWhileLoopWritePackGradients(
|
||||
# dynamic_size=False, dtype=tf.int64)
|
||||
|
||||
@test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA")
|
||||
@test_util.enable_tensor_array_v2
|
||||
def testWhileLoopV1WithTensorArrayV2(self):
|
||||
size = 3
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar())
|
||||
|
||||
def Body(counter, ta):
|
||||
return counter + 1, ta.write(counter, counter)
|
||||
|
||||
_, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta])
|
||||
|
||||
for i in range(size):
|
||||
self.assertEqual(self.evaluate(ta.read(i)), i)
|
||||
|
||||
@test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
|
||||
@test_util.run_v1_only("b/117943489")
|
||||
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
|
||||
@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase):
|
||||
self.assertEqual((2, 2), w0.read(1).get_shape())
|
||||
else:
|
||||
self.assertEqual(r0.get_shape().ndims, None)
|
||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
|
||||
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
self.assertEqual(
|
||||
tensor_shape.TensorShape(
|
||||
ta1.handle.op.get_attr("element_shape")).ndims, None)
|
||||
@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase):
|
||||
"TensorArray has size zero, but element shape <unknown> is not "
|
||||
"fully defined. Currently only static shapes are supported when "
|
||||
"packing zero-size TensorArrays.")
|
||||
with self.assertRaisesOpError(v2_msg if tensor_array_ops
|
||||
.ENABLE_TENSOR_ARRAY_V2 else v1_msg):
|
||||
with self.assertRaisesOpError(
|
||||
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
||||
ta.stack().eval()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
|
@ -24,13 +24,11 @@ from __future__ import print_function
|
||||
import abc
|
||||
import collections
|
||||
import functools
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf import control_flow_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(),
|
||||
while_v2 = LazyLoader("while_v2", globals(),
|
||||
"tensorflow.python.ops.while_v2")
|
||||
|
||||
ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0"
|
||||
ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
|
||||
|
||||
# We override the 'tuple' for a control flow op, so we keep python's
|
||||
# existing 'tuple' for later use in this module.
|
||||
_basetuple = tuple
|
||||
@ -2052,7 +2047,7 @@ def cond(pred,
|
||||
```
|
||||
|
||||
"""
|
||||
if ENABLE_COND_V2 and not context.executing_eagerly():
|
||||
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
|
||||
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
||||
|
||||
# We needed to make true_fn/false_fn keyword arguments for
|
||||
@ -3487,7 +3482,7 @@ def while_loop(cond,
|
||||
```
|
||||
|
||||
"""
|
||||
if ENABLE_WHILE_V2 and not context.executing_eagerly():
|
||||
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
|
||||
return while_v2.while_loop(
|
||||
cond,
|
||||
body,
|
||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark):
|
||||
iters=self.NUM_ITERS)
|
||||
|
||||
def benchmark_cond_v1_defun(self):
|
||||
old_val = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = False
|
||||
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
|
||||
self._benchmark_defun()
|
||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||
|
||||
def benchmark_cond_v2_defun(self):
|
||||
old_val = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||
self._benchmark_defun()
|
||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||
|
||||
def benchmark_cond_v1_graph(self):
|
||||
old_val = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = False
|
||||
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
|
||||
self._benchmark_graph()
|
||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||
|
||||
def benchmark_cond_v2_graph(self):
|
||||
old_val = control_flow_ops.ENABLE_COND_V2
|
||||
control_flow_ops.ENABLE_COND_V2 = True
|
||||
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||
self._benchmark_graph()
|
||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
||||
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
|
@ -23,10 +23,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
|
||||
ENABLE_CONTROL_FLOW_V2 = (tf2.enabled() or
|
||||
os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
|
||||
os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
|
||||
|
||||
|
||||
def IsInXLAContext(op):
|
||||
try:
|
||||
|
@ -20,10 +20,8 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -32,6 +30,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_util
|
||||
from tensorflow.python.ops import gen_control_flow_ops
|
||||
from tensorflow.python.ops import gen_data_flow_ops
|
||||
from tensorflow.python.ops import list_ops
|
||||
@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
ENABLE_TENSOR_ARRAY_V2 = (
|
||||
tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None)
|
||||
|
||||
|
||||
# _GraphTensorArray accesses many of the hidden generated ops, but is in
|
||||
# fact built to wrap these methods.
|
||||
# pylint: disable=protected-access
|
||||
@ -1013,7 +1008,7 @@ class TensorArray(object):
|
||||
if context.executing_eagerly():
|
||||
implementation = _EagerTensorArray
|
||||
else:
|
||||
if ENABLE_TENSOR_ARRAY_V2:
|
||||
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||
implementation = _GraphTensorArrayV2
|
||||
else:
|
||||
implementation = _GraphTensorArray
|
||||
|
@ -52,13 +52,6 @@ from tensorflow.python.util import nest
|
||||
# to them and then pass those in as data inputs. This should probably be
|
||||
# handled in the CapturingGraph itself.
|
||||
|
||||
# Op types that output a resource tensor representing a TensorArray handle.
|
||||
TENSOR_ARRAY_HANDLE_OPS = (
|
||||
"TensorArrayV3",
|
||||
"TensorArrayGradV3",
|
||||
"TensorArrayGradWithShape",
|
||||
)
|
||||
|
||||
|
||||
def while_loop(cond,
|
||||
body,
|
||||
@ -257,24 +250,19 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
||||
"_maximum_iterations") if _is_in_xla_context() else None
|
||||
assert not _is_in_xla_context() or maximum_iterations is not None
|
||||
|
||||
# Set the incoming gradient of TensorArray handles to None. The gradient
|
||||
# implementation currently assumes all resource tensors correspond to float32
|
||||
# ResourceVariables, which can lead to runtime shape errors when used with a
|
||||
# TensorArray. This is a workaround until TensorArrays are reimplemented with
|
||||
# TensorLists instead of resources.
|
||||
# Also set the incoming gradient of non-trainable inputs to None. It is
|
||||
# possible that we receive non-None gradients for non-trainable types in
|
||||
# nested while loops because we accumulate outputs of the inner while as
|
||||
# variant tensors which are trainable and hence receive zeros_like tensors in
|
||||
# the gradient pass. The non-trainable tensors then receive the popped zeros
|
||||
# tensor from this zeros variant. The gradient for the loop vars corresponding
|
||||
# to these tensors is None or zeros (this happens only if the loop var is
|
||||
# accumulated as well) in _grad_fn so we reset these.
|
||||
# Set the incoming gradient of non-trainable inputs to None. It is possible
|
||||
# that we receive non-None gradients for non-trainable types in nested while
|
||||
# loops because we accumulate outputs of the inner while as variant tensors
|
||||
# which are trainable and hence receive zeros_like tensors in the gradient
|
||||
# pass. The non-trainable tensors then receive the popped zeros tensor from
|
||||
# this zeros variant. The gradient for the loop vars corresponding to these
|
||||
# tensors is None or zeros (this happens only if the loop var is accumulated
|
||||
# as well) in _grad_fn so we reset these.
|
||||
# TODO(b/118712257): Remove the IsTrainable filter once we can handle None
|
||||
# output grads in _grad_fn.
|
||||
grads = [
|
||||
None if _is_tensor_array_handle(output) or not _is_trainable(output)
|
||||
else grad for grad, output in zip(grads, body_graph.outputs)
|
||||
None if not _is_trainable(output) else grad
|
||||
for grad, output in zip(grads, body_graph.outputs)
|
||||
]
|
||||
|
||||
# Ensure that all non-resource trainable outputs have incoming gradients.
|
||||
@ -339,8 +327,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
||||
# See comment in while_loop.
|
||||
outputs = [array_ops.identity(t) for t in outputs]
|
||||
|
||||
# Set None as the output gradient for tensors with None input gradient
|
||||
# e.g. TensorArray handles.
|
||||
# Set None as the output gradient for tensors with None input gradient.
|
||||
# outputs[0] is the loop counter.
|
||||
# outputs[1] is the total number of loop iterations.
|
||||
index = 2
|
||||
@ -853,28 +840,6 @@ def _graph_name(graph):
|
||||
return "Base"
|
||||
|
||||
|
||||
def _is_tensor_array_handle(tensor):
|
||||
"""Returns whether tensor is a TensorArray handle."""
|
||||
if tensor.dtype != dtypes.resource:
|
||||
return False
|
||||
|
||||
if tensor.op.type == "While":
|
||||
# We assume that any resource outputs of a While op correspond to a captured
|
||||
# resource input (as opposed to a loop variable specified by the user).
|
||||
# NOTE(skyewm): we could actually check this, but I can't think of when you
|
||||
# would have a resource loop variable.
|
||||
tensor = tensor.op.inputs[tensor.value_index]
|
||||
|
||||
# TODO(b/118452219): add test coverage for this.
|
||||
tensor = func_graph_module.maybe_captured(tensor)
|
||||
|
||||
if isinstance(tensor, ops.EagerTensor):
|
||||
# Eager execution doesn't quite support legacy tensorarray
|
||||
return False
|
||||
|
||||
return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
|
||||
|
||||
|
||||
def _pack_sequence_as(structure_with_tas, loop_vars):
|
||||
"""Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user