Introduce consolidated ENABLE_CONTROL_FLOW_V2 flag.
The new toggle replaces ENABLE_COND_V2, ENABLE_WHILE_V2, and ENABLE_TENSOR_ARRAY_V2. This means that these can't be toggled independently anymore, notably that v1 TensorArrays can only be run with v1 loops, and v2 TensorArrays with v2 loops. This also introduces a corresponding environment variable TF_ENABLE_CONTROL_FLOW_V2. I kept the old env vars as well in case people are using them. They all flip the new single toggle now. In addition, this change removes some while_v2 code for dealing with v1 TensorArrays, since this is no longer a supported configuration. PiperOrigin-RevId: 224862245
This commit is contained in:
parent
ee418c8ee2
commit
1d54cbf4a2
@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testMapAndBatchControlFlow(self, numa_aware):
|
def testMapAndBatchControlFlow(self, numa_aware):
|
||||||
|
|
||||||
def map_fn(x):
|
def map_fn(x):
|
||||||
previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
|
previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
control_flow_ops.ENABLE_COND_V2 = True
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||||
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
|
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
|
||||||
control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
|
||||||
return return_value
|
return return_value
|
||||||
|
|
||||||
dataset = dataset_ops.Dataset.range(100).apply(
|
dataset = dataset_ops.Dataset.range(100).apply(
|
||||||
|
@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import versions
|
from tensorflow.python.framework import versions
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import script_ops
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
@ -409,42 +408,12 @@ def enable_control_flow_v2(fn):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
|
enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||||
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
|
|
||||||
control_flow_ops.ENABLE_COND_V2 = True
|
|
||||||
control_flow_ops.ENABLE_WHILE_V2 = True
|
|
||||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
|
|
||||||
try:
|
try:
|
||||||
fn(*args, **kwargs)
|
fn(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old
|
||||||
control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
|
|
||||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def enable_tensor_array_v2(fn):
|
|
||||||
"""Decorator for enabling _GraphTensorArrayV2 on a test.
|
|
||||||
|
|
||||||
Note this enables _GraphTensorArrayV2 after running the test class's
|
|
||||||
setup/teardown methods.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fn: the function to be wrapped
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The wrapped function
|
|
||||||
"""
|
|
||||||
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
|
|
||||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
|
|
||||||
try:
|
|
||||||
fn(*args, **kwargs)
|
|
||||||
finally:
|
|
||||||
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
|
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
@ -493,7 +462,7 @@ def with_control_flow_v2(cls):
|
|||||||
Returns:
|
Returns:
|
||||||
cls with new test methods added
|
cls with new test methods added
|
||||||
"""
|
"""
|
||||||
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
for name, value in cls.__dict__.copy().items():
|
for name, value in cls.__dict__.copy().items():
|
||||||
|
@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
from tensorflow.python.ops import gen_array_ops
|
from tensorflow.python.ops import gen_array_ops
|
||||||
@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase):
|
|||||||
v1_msg = "The two structures don't have the same nested structure"
|
v1_msg = "The two structures don't have the same nested structure"
|
||||||
v2_msg = "Outputs of true_fn and false_fn must have the same structure"
|
v2_msg = "Outputs of true_fn and false_fn must have the same structure"
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg):
|
ValueError,
|
||||||
|
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
||||||
r = control_flow_ops.cond(pred, fn1, fn2)
|
r = control_flow_ops.cond(pred, fn1, fn2)
|
||||||
self.evaluate(r)
|
self.evaluate(r)
|
||||||
|
|
||||||
@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
|
self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
|
||||||
|
|
||||||
# v1 control flow gets None second derivative for some reason.
|
# v1 control flow gets None second derivative for some reason.
|
||||||
if not control_flow_ops.ENABLE_COND_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
self.assertIsNone(grad_grad)
|
self.assertIsNone(grad_grad)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
|
|
||||||
# In defuns, all prints should execute in program order.
|
# In defuns, all prints should execute in program order.
|
||||||
# This doesn't work with legacy control flow.
|
# This doesn't work with legacy control flow.
|
||||||
if control_flow_ops.ENABLE_COND_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
|
|
||||||
@eager_function.defun
|
@eager_function.defun
|
||||||
def cond():
|
def cond():
|
||||||
@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
|
|
||||||
# In defuns, all prints should execute in program order.
|
# In defuns, all prints should execute in program order.
|
||||||
# This doesn't work with legacy control flow.
|
# This doesn't work with legacy control flow.
|
||||||
if control_flow_ops.ENABLE_WHILE_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
|
|
||||||
@eager_function.defun
|
@eager_function.defun
|
||||||
def while_loop():
|
def while_loop():
|
||||||
@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
gs = gradients_impl.gradients(loop_no_xla, v)
|
gs = gradients_impl.gradients(loop_no_xla, v)
|
||||||
self.evaluate(gs) # This should execute without error.
|
self.evaluate(gs) # This should execute without error.
|
||||||
|
|
||||||
if control_flow_ops.ENABLE_WHILE_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||||
xla_context.Enter()
|
xla_context.Enter()
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
lambda i, x: (i + 1, v * x), (0, 1.0),
|
lambda i, x: (i + 1, v * x), (0, 1.0),
|
||||||
maximum_iterations=max_iter_holder[0])
|
maximum_iterations=max_iter_holder[0])
|
||||||
|
|
||||||
if control_flow_ops.ENABLE_WHILE_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||||
xla_context.Enter()
|
xla_context.Enter()
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertEqual(sess.run(grad, {pred: True}), 8.0)
|
self.assertEqual(sess.run(grad, {pred: True}), 8.0)
|
||||||
self.assertEqual(sess.run(grad, {pred: False}), 0.0)
|
self.assertEqual(sess.run(grad, {pred: False}), 0.0)
|
||||||
|
|
||||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
|
self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
|
||||||
@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
# outer_loop(x) = g(g(x)) = 4x + 81
|
# outer_loop(x) = g(g(x)) = 4x + 81
|
||||||
# outer_loop'(x) = 4
|
# outer_loop'(x) = 4
|
||||||
# Note that v1 control flow gets 4.0 as well if the cond is removed.
|
# Note that v1 control flow gets 4.0 as well if the cond is removed.
|
||||||
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
self.assertEqual(grad, 4.0)
|
self.assertEqual(grad, 4.0)
|
||||||
|
|
||||||
def testWhile_NestedInput(self):
|
def testWhile_NestedInput(self):
|
||||||
@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
|
|
||||||
result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
|
result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
|
||||||
grad_theta = gradients_impl.gradients(result, theta)
|
grad_theta = gradients_impl.gradients(result, theta)
|
||||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
|
with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
|
||||||
gradients_impl.gradients(grad_theta, theta)
|
gradients_impl.gradients(grad_theta, theta)
|
||||||
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
|
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
|
||||||
@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase):
|
|||||||
self.assertEqual(r[1].eval(), 65536.0)
|
self.assertEqual(r[1].eval(), 65536.0)
|
||||||
self.assertEqual(grad.eval(), 524288.0)
|
self.assertEqual(grad.eval(), 524288.0)
|
||||||
# while_v2 does not have stacks.
|
# while_v2 does not have stacks.
|
||||||
if not control_flow_ops.ENABLE_WHILE_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len([op for op in x.graph.get_operations() if op.type == "StackV2"
|
len([op for op in x.graph.get_operations() if op.type == "StackV2"
|
||||||
]), 1)
|
]), 1)
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.python.eager import function
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import control_flow_util_v2
|
from tensorflow.python.ops import control_flow_util_v2
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -30,14 +31,11 @@ from tensorflow.python.platform import test
|
|||||||
class ControlFlowUtilV2Test(test.TestCase):
|
class ControlFlowUtilV2Test(test.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
|
self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||||
control_flow_ops.ENABLE_COND_V2 = True
|
|
||||||
control_flow_ops.ENABLE_WHILE_V2 = True
|
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old
|
||||||
control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old
|
|
||||||
|
|
||||||
def _create_control_flow(self, expect_in_defun):
|
def _create_control_flow(self, expect_in_defun):
|
||||||
"""Helper method for testInDefun."""
|
"""Helper method for testInDefun."""
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import data_flow_ops
|
from tensorflow.python.ops import data_flow_ops
|
||||||
from tensorflow.python.ops import gen_data_flow_ops
|
from tensorflow.python.ops import gen_data_flow_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testSkipEagerTensorArrayGradGrad(self):
|
def testSkipEagerTensorArrayGradGrad(self):
|
||||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
self.skipTest("Legacy TensorArray does not support double derivatives.")
|
self.skipTest("Legacy TensorArray does not support double derivatives.")
|
||||||
with self.test_session(use_gpu=True) as session:
|
with self.test_session(use_gpu=True) as session:
|
||||||
x = constant_op.constant(4.0)
|
x = constant_op.constant(4.0)
|
||||||
@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
ta = _make_ta(3, "foo", dtype=dtypes.float32)
|
ta = _make_ta(3, "foo", dtype=dtypes.float32)
|
||||||
# Test writing the wrong datatype
|
# Test writing the wrong datatype
|
||||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not context.executing_eagerly()):
|
not context.executing_eagerly()):
|
||||||
error_msg = ("Invalid data types; op elements string but list elements "
|
error_msg = ("Invalid data types; op elements string but list elements "
|
||||||
"float")
|
"float")
|
||||||
@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError(error_msg):
|
with self.assertRaisesOpError(error_msg):
|
||||||
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
|
||||||
|
|
||||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not context.executing_eagerly()):
|
not context.executing_eagerly()):
|
||||||
error_msg = "Trying to modify element -1 in a list with 3 elements."
|
error_msg = "Trying to modify element -1 in a list with 3 elements."
|
||||||
else:
|
else:
|
||||||
@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError(error_msg):
|
with self.assertRaisesOpError(error_msg):
|
||||||
self.evaluate(ta.write(-1, 3.0).flow)
|
self.evaluate(ta.write(-1, 3.0).flow)
|
||||||
|
|
||||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not context.executing_eagerly()):
|
not context.executing_eagerly()):
|
||||||
error_msg = "Trying to modify element 3 in a list with 3 elements"
|
error_msg = "Trying to modify element 3 in a list with 3 elements"
|
||||||
else:
|
else:
|
||||||
@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase):
|
|||||||
|
|
||||||
# Test reading wrong datatype (only possible when constructing graphs).
|
# Test reading wrong datatype (only possible when constructing graphs).
|
||||||
if (not context.executing_eagerly() and
|
if (not context.executing_eagerly() and
|
||||||
not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2):
|
not control_flow_util.ENABLE_CONTROL_FLOW_V2):
|
||||||
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
|
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
|
||||||
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
|
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
"TensorArray dtype is float but Op requested dtype double."):
|
"TensorArray dtype is float but Op requested dtype double."):
|
||||||
self.evaluate(r0_bad)
|
self.evaluate(r0_bad)
|
||||||
|
|
||||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not context.executing_eagerly()):
|
not context.executing_eagerly()):
|
||||||
error_msg = "Trying to access element -1 in a list with 3 elements."
|
error_msg = "Trying to access element -1 in a list with 3 elements."
|
||||||
else:
|
else:
|
||||||
@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
with self.assertRaisesOpError(error_msg):
|
with self.assertRaisesOpError(error_msg):
|
||||||
self.evaluate(ta.read(-1))
|
self.evaluate(ta.read(-1))
|
||||||
|
|
||||||
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not context.executing_eagerly()):
|
not context.executing_eagerly()):
|
||||||
error_msg = "Trying to access element 3 in a list with 3 elements."
|
error_msg = "Trying to access element 3 in a list with 3 elements."
|
||||||
else:
|
else:
|
||||||
@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
|
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
|
||||||
|
|
||||||
error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
|
error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
|
||||||
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
|
||||||
not in_eager_mode else
|
not in_eager_mode else
|
||||||
r"Expected sum of lengths to be equal to values.shape\[0\], "
|
r"Expected sum of lengths to be equal to values.shape\[0\], "
|
||||||
r"but sum of lengths is 1 and value's shape is: \[3\]")
|
r"but sum of lengths is 1 and value's shape is: \[3\]")
|
||||||
@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
|
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
|
||||||
|
|
||||||
ta = _make_ta(1, "baz")
|
ta = _make_ta(1, "baz")
|
||||||
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, "Shape must be at least rank 1 but is rank 0"):
|
ValueError, "Shape must be at least rank 1 but is rank 0"):
|
||||||
self.evaluate(ta.split(1.0, [1]).flow)
|
self.evaluate(ta.split(1.0, [1]).flow)
|
||||||
@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
):
|
):
|
||||||
self.evaluate(ta.split(1.0, [1]).flow)
|
self.evaluate(ta.split(1.0, [1]).flow)
|
||||||
|
|
||||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
|
||||||
ta = _make_ta(2, "buz")
|
ta = _make_ta(2, "buz")
|
||||||
with self.assertRaisesOpError(
|
with self.assertRaisesOpError(
|
||||||
r"TensorArray's size is not equal to the size of lengths "
|
r"TensorArray's size is not equal to the size of lengths "
|
||||||
@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase):
|
|||||||
# self._testWhileLoopWritePackGradients(
|
# self._testWhileLoopWritePackGradients(
|
||||||
# dynamic_size=False, dtype=tf.int64)
|
# dynamic_size=False, dtype=tf.int64)
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA")
|
|
||||||
@test_util.enable_tensor_array_v2
|
|
||||||
def testWhileLoopV1WithTensorArrayV2(self):
|
|
||||||
size = 3
|
|
||||||
ta = tensor_array_ops.TensorArray(
|
|
||||||
dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar())
|
|
||||||
|
|
||||||
def Body(counter, ta):
|
|
||||||
return counter + 1, ta.write(counter, counter)
|
|
||||||
|
|
||||||
_, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta])
|
|
||||||
|
|
||||||
for i in range(size):
|
|
||||||
self.assertEqual(self.evaluate(ta.read(i)), i)
|
|
||||||
|
|
||||||
@test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
|
@test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
|
||||||
@test_util.run_v1_only("b/117943489")
|
@test_util.run_v1_only("b/117943489")
|
||||||
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
|
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
|
||||||
@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase):
|
|||||||
self.assertEqual((2, 2), w0.read(1).get_shape())
|
self.assertEqual((2, 2), w0.read(1).get_shape())
|
||||||
else:
|
else:
|
||||||
self.assertEqual(r0.get_shape().ndims, None)
|
self.assertEqual(r0.get_shape().ndims, None)
|
||||||
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
|
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
tensor_shape.TensorShape(
|
tensor_shape.TensorShape(
|
||||||
ta1.handle.op.get_attr("element_shape")).ndims, None)
|
ta1.handle.op.get_attr("element_shape")).ndims, None)
|
||||||
@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase):
|
|||||||
"TensorArray has size zero, but element shape <unknown> is not "
|
"TensorArray has size zero, but element shape <unknown> is not "
|
||||||
"fully defined. Currently only static shapes are supported when "
|
"fully defined. Currently only static shapes are supported when "
|
||||||
"packing zero-size TensorArrays.")
|
"packing zero-size TensorArrays.")
|
||||||
with self.assertRaisesOpError(v2_msg if tensor_array_ops
|
with self.assertRaisesOpError(
|
||||||
.ENABLE_TENSOR_ARRAY_V2 else v1_msg):
|
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
|
||||||
ta.stack().eval()
|
ta.stack().eval()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
|
@ -24,13 +24,11 @@ from __future__ import print_function
|
|||||||
import abc
|
import abc
|
||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import os
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from tensorflow.core.framework import attr_value_pb2
|
from tensorflow.core.framework import attr_value_pb2
|
||||||
from tensorflow.core.protobuf import control_flow_pb2
|
from tensorflow.core.protobuf import control_flow_pb2
|
||||||
from tensorflow.python import tf2
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(),
|
|||||||
while_v2 = LazyLoader("while_v2", globals(),
|
while_v2 = LazyLoader("while_v2", globals(),
|
||||||
"tensorflow.python.ops.while_v2")
|
"tensorflow.python.ops.while_v2")
|
||||||
|
|
||||||
ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0"
|
|
||||||
ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
|
|
||||||
|
|
||||||
# We override the 'tuple' for a control flow op, so we keep python's
|
# We override the 'tuple' for a control flow op, so we keep python's
|
||||||
# existing 'tuple' for later use in this module.
|
# existing 'tuple' for later use in this module.
|
||||||
_basetuple = tuple
|
_basetuple = tuple
|
||||||
@ -2052,7 +2047,7 @@ def cond(pred,
|
|||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if ENABLE_COND_V2 and not context.executing_eagerly():
|
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
|
||||||
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
|
||||||
|
|
||||||
# We needed to make true_fn/false_fn keyword arguments for
|
# We needed to make true_fn/false_fn keyword arguments for
|
||||||
@ -3487,7 +3482,7 @@ def while_loop(cond,
|
|||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if ENABLE_WHILE_V2 and not context.executing_eagerly():
|
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
|
||||||
return while_v2.while_loop(
|
return while_v2.while_loop(
|
||||||
cond,
|
cond,
|
||||||
body,
|
body,
|
||||||
|
@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark):
|
|||||||
iters=self.NUM_ITERS)
|
iters=self.NUM_ITERS)
|
||||||
|
|
||||||
def benchmark_cond_v1_defun(self):
|
def benchmark_cond_v1_defun(self):
|
||||||
old_val = control_flow_ops.ENABLE_COND_V2
|
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
control_flow_ops.ENABLE_COND_V2 = False
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
|
||||||
self._benchmark_defun()
|
self._benchmark_defun()
|
||||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||||
|
|
||||||
def benchmark_cond_v2_defun(self):
|
def benchmark_cond_v2_defun(self):
|
||||||
old_val = control_flow_ops.ENABLE_COND_V2
|
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
control_flow_ops.ENABLE_COND_V2 = True
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||||
self._benchmark_defun()
|
self._benchmark_defun()
|
||||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||||
|
|
||||||
def benchmark_cond_v1_graph(self):
|
def benchmark_cond_v1_graph(self):
|
||||||
old_val = control_flow_ops.ENABLE_COND_V2
|
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
control_flow_ops.ENABLE_COND_V2 = False
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
|
||||||
self._benchmark_graph()
|
self._benchmark_graph()
|
||||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||||
|
|
||||||
def benchmark_cond_v2_graph(self):
|
def benchmark_cond_v2_graph(self):
|
||||||
old_val = control_flow_ops.ENABLE_COND_V2
|
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
|
||||||
control_flow_ops.ENABLE_COND_V2 = True
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
|
||||||
self._benchmark_graph()
|
self._benchmark_graph()
|
||||||
control_flow_ops.ENABLE_COND_V2 = old_val
|
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
@ -23,10 +23,18 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
ENABLE_CONTROL_FLOW_V2 = (tf2.enabled() or
|
||||||
|
os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
|
||||||
|
os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
|
||||||
|
os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
|
||||||
|
os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
|
||||||
|
|
||||||
|
|
||||||
def IsInXLAContext(op):
|
def IsInXLAContext(op):
|
||||||
try:
|
try:
|
||||||
|
@ -20,10 +20,8 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import os
|
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
from tensorflow.python import tf2
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -32,6 +30,7 @@ from tensorflow.python.framework import ops
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import control_flow_util
|
||||||
from tensorflow.python.ops import gen_control_flow_ops
|
from tensorflow.python.ops import gen_control_flow_ops
|
||||||
from tensorflow.python.ops import gen_data_flow_ops
|
from tensorflow.python.ops import gen_data_flow_ops
|
||||||
from tensorflow.python.ops import list_ops
|
from tensorflow.python.ops import list_ops
|
||||||
@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
ENABLE_TENSOR_ARRAY_V2 = (
|
|
||||||
tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None)
|
|
||||||
|
|
||||||
|
|
||||||
# _GraphTensorArray accesses many of the hidden generated ops, but is in
|
# _GraphTensorArray accesses many of the hidden generated ops, but is in
|
||||||
# fact built to wrap these methods.
|
# fact built to wrap these methods.
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
@ -1013,7 +1008,7 @@ class TensorArray(object):
|
|||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
implementation = _EagerTensorArray
|
implementation = _EagerTensorArray
|
||||||
else:
|
else:
|
||||||
if ENABLE_TENSOR_ARRAY_V2:
|
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
|
||||||
implementation = _GraphTensorArrayV2
|
implementation = _GraphTensorArrayV2
|
||||||
else:
|
else:
|
||||||
implementation = _GraphTensorArray
|
implementation = _GraphTensorArray
|
||||||
|
@ -52,13 +52,6 @@ from tensorflow.python.util import nest
|
|||||||
# to them and then pass those in as data inputs. This should probably be
|
# to them and then pass those in as data inputs. This should probably be
|
||||||
# handled in the CapturingGraph itself.
|
# handled in the CapturingGraph itself.
|
||||||
|
|
||||||
# Op types that output a resource tensor representing a TensorArray handle.
|
|
||||||
TENSOR_ARRAY_HANDLE_OPS = (
|
|
||||||
"TensorArrayV3",
|
|
||||||
"TensorArrayGradV3",
|
|
||||||
"TensorArrayGradWithShape",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def while_loop(cond,
|
def while_loop(cond,
|
||||||
body,
|
body,
|
||||||
@ -257,24 +250,19 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
|||||||
"_maximum_iterations") if _is_in_xla_context() else None
|
"_maximum_iterations") if _is_in_xla_context() else None
|
||||||
assert not _is_in_xla_context() or maximum_iterations is not None
|
assert not _is_in_xla_context() or maximum_iterations is not None
|
||||||
|
|
||||||
# Set the incoming gradient of TensorArray handles to None. The gradient
|
# Set the incoming gradient of non-trainable inputs to None. It is possible
|
||||||
# implementation currently assumes all resource tensors correspond to float32
|
# that we receive non-None gradients for non-trainable types in nested while
|
||||||
# ResourceVariables, which can lead to runtime shape errors when used with a
|
# loops because we accumulate outputs of the inner while as variant tensors
|
||||||
# TensorArray. This is a workaround until TensorArrays are reimplemented with
|
# which are trainable and hence receive zeros_like tensors in the gradient
|
||||||
# TensorLists instead of resources.
|
# pass. The non-trainable tensors then receive the popped zeros tensor from
|
||||||
# Also set the incoming gradient of non-trainable inputs to None. It is
|
# this zeros variant. The gradient for the loop vars corresponding to these
|
||||||
# possible that we receive non-None gradients for non-trainable types in
|
# tensors is None or zeros (this happens only if the loop var is accumulated
|
||||||
# nested while loops because we accumulate outputs of the inner while as
|
# as well) in _grad_fn so we reset these.
|
||||||
# variant tensors which are trainable and hence receive zeros_like tensors in
|
|
||||||
# the gradient pass. The non-trainable tensors then receive the popped zeros
|
|
||||||
# tensor from this zeros variant. The gradient for the loop vars corresponding
|
|
||||||
# to these tensors is None or zeros (this happens only if the loop var is
|
|
||||||
# accumulated as well) in _grad_fn so we reset these.
|
|
||||||
# TODO(b/118712257): Remove the IsTrainable filter once we can handle None
|
# TODO(b/118712257): Remove the IsTrainable filter once we can handle None
|
||||||
# output grads in _grad_fn.
|
# output grads in _grad_fn.
|
||||||
grads = [
|
grads = [
|
||||||
None if _is_tensor_array_handle(output) or not _is_trainable(output)
|
None if not _is_trainable(output) else grad
|
||||||
else grad for grad, output in zip(grads, body_graph.outputs)
|
for grad, output in zip(grads, body_graph.outputs)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Ensure that all non-resource trainable outputs have incoming gradients.
|
# Ensure that all non-resource trainable outputs have incoming gradients.
|
||||||
@ -339,8 +327,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
|
|||||||
# See comment in while_loop.
|
# See comment in while_loop.
|
||||||
outputs = [array_ops.identity(t) for t in outputs]
|
outputs = [array_ops.identity(t) for t in outputs]
|
||||||
|
|
||||||
# Set None as the output gradient for tensors with None input gradient
|
# Set None as the output gradient for tensors with None input gradient.
|
||||||
# e.g. TensorArray handles.
|
|
||||||
# outputs[0] is the loop counter.
|
# outputs[0] is the loop counter.
|
||||||
# outputs[1] is the total number of loop iterations.
|
# outputs[1] is the total number of loop iterations.
|
||||||
index = 2
|
index = 2
|
||||||
@ -853,28 +840,6 @@ def _graph_name(graph):
|
|||||||
return "Base"
|
return "Base"
|
||||||
|
|
||||||
|
|
||||||
def _is_tensor_array_handle(tensor):
|
|
||||||
"""Returns whether tensor is a TensorArray handle."""
|
|
||||||
if tensor.dtype != dtypes.resource:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if tensor.op.type == "While":
|
|
||||||
# We assume that any resource outputs of a While op correspond to a captured
|
|
||||||
# resource input (as opposed to a loop variable specified by the user).
|
|
||||||
# NOTE(skyewm): we could actually check this, but I can't think of when you
|
|
||||||
# would have a resource loop variable.
|
|
||||||
tensor = tensor.op.inputs[tensor.value_index]
|
|
||||||
|
|
||||||
# TODO(b/118452219): add test coverage for this.
|
|
||||||
tensor = func_graph_module.maybe_captured(tensor)
|
|
||||||
|
|
||||||
if isinstance(tensor, ops.EagerTensor):
|
|
||||||
# Eager execution doesn't quite support legacy tensorarray
|
|
||||||
return False
|
|
||||||
|
|
||||||
return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
|
|
||||||
|
|
||||||
|
|
||||||
def _pack_sequence_as(structure_with_tas, loop_vars):
|
def _pack_sequence_as(structure_with_tas, loop_vars):
|
||||||
"""Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
|
"""Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user