Introduce consolidated ENABLE_CONTROL_FLOW_V2 flag.

The new toggle replaces ENABLE_COND_V2, ENABLE_WHILE_V2, and
ENABLE_TENSOR_ARRAY_V2. This means that these can't be toggled
independently anymore, notably that v1 TensorArrays can only be run
with v1 loops, and v2 TensorArrays with v2 loops.

This also introduces a corresponding environment variable
TF_ENABLE_CONTROL_FLOW_V2. I kept the old env vars as well in case
people are using them. They all flip the new single toggle now.

In addition, this change removes some while_v2 code for dealing with
v1 TensorArrays, since this is no longer a supported configuration.

PiperOrigin-RevId: 224862245
This commit is contained in:
Skye Wanderman-Milne 2018-12-10 12:35:24 -08:00 committed by TensorFlower Gardener
parent ee418c8ee2
commit 1d54cbf4a2
10 changed files with 75 additions and 155 deletions

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMapAndBatchControlFlow(self, numa_aware): def testMapAndBatchControlFlow(self, numa_aware):
def map_fn(x): def map_fn(x):
previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2 previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_ops.ENABLE_COND_V2 = True control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x) return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
return return_value return return_value
dataset = dataset_ops.Dataset.range(100).apply( dataset = dataset_ops.Dataset.range(100).apply(

View File

@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import script_ops from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
@ -409,42 +408,12 @@ def enable_control_flow_v2(fn):
""" """
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2 enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
control_flow_ops.ENABLE_COND_V2 = True
control_flow_ops.ENABLE_WHILE_V2 = True
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
try: try:
fn(*args, **kwargs) fn(*args, **kwargs)
finally: finally:
control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old
control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
return wrapper
def enable_tensor_array_v2(fn):
"""Decorator for enabling _GraphTensorArrayV2 on a test.
Note this enables _GraphTensorArrayV2 after running the test class's
setup/teardown methods.
Args:
fn: the function to be wrapped
Returns:
The wrapped function
"""
def wrapper(*args, **kwargs):
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
try:
fn(*args, **kwargs)
finally:
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
return wrapper return wrapper
@ -493,7 +462,7 @@ def with_control_flow_v2(cls):
Returns: Returns:
cls with new test methods added cls with new test methods added
""" """
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
return cls return cls
for name, value in cls.__dict__.copy().items(): for name, value in cls.__dict__.copy().items():

View File

@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_array_ops
@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase):
v1_msg = "The two structures don't have the same nested structure" v1_msg = "The two structures don't have the same nested structure"
v2_msg = "Outputs of true_fn and false_fn must have the same structure" v2_msg = "Outputs of true_fn and false_fn must have the same structure"
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg): ValueError,
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
r = control_flow_ops.cond(pred, fn1, fn2) r = control_flow_ops.cond(pred, fn1, fn2)
self.evaluate(r) self.evaluate(r)
@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0) self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
# v1 control flow gets None second derivative for some reason. # v1 control flow gets None second derivative for some reason.
if not control_flow_ops.ENABLE_COND_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertIsNone(grad_grad) self.assertIsNone(grad_grad)
return return
@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase):
# In defuns, all prints should execute in program order. # In defuns, all prints should execute in program order.
# This doesn't work with legacy control flow. # This doesn't work with legacy control flow.
if control_flow_ops.ENABLE_COND_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
@eager_function.defun @eager_function.defun
def cond(): def cond():
@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase):
# In defuns, all prints should execute in program order. # In defuns, all prints should execute in program order.
# This doesn't work with legacy control flow. # This doesn't work with legacy control flow.
if control_flow_ops.ENABLE_WHILE_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
@eager_function.defun @eager_function.defun
def while_loop(): def while_loop():
@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase):
gs = gradients_impl.gradients(loop_no_xla, v) gs = gradients_impl.gradients(loop_no_xla, v)
self.evaluate(gs) # This should execute without error. self.evaluate(gs) # This should execute without error.
if control_flow_ops.ENABLE_WHILE_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase):
lambda i, x: (i + 1, v * x), (0, 1.0), lambda i, x: (i + 1, v * x), (0, 1.0),
maximum_iterations=max_iter_holder[0]) maximum_iterations=max_iter_holder[0])
if control_flow_ops.ENABLE_WHILE_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
xla_context = control_flow_ops.XLAControlFlowContext() xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter() xla_context.Enter()
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(sess.run(grad, {pred: True}), 8.0) self.assertEqual(sess.run(grad, {pred: True}), 8.0)
self.assertEqual(sess.run(grad, {pred: False}), 0.0) self.assertEqual(sess.run(grad, {pred: False}), 0.0)
if not control_flow_ops.ENABLE_WHILE_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
return return
self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0) self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase):
# outer_loop(x) = g(g(x)) = 4x + 81 # outer_loop(x) = g(g(x)) = 4x + 81
# outer_loop'(x) = 4 # outer_loop'(x) = 4
# Note that v1 control flow gets 4.0 as well if the cond is removed. # Note that v1 control flow gets 4.0 as well if the cond is removed.
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual(grad, 4.0) self.assertEqual(grad, 4.0)
def testWhile_NestedInput(self): def testWhile_NestedInput(self):
@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase):
result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
grad_theta = gradients_impl.gradients(result, theta) grad_theta = gradients_impl.gradients(result, theta)
if not control_flow_ops.ENABLE_WHILE_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
with self.assertRaisesRegexp(TypeError, "Second-order gradient"): with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
gradients_impl.gradients(grad_theta, theta) gradients_impl.gradients(grad_theta, theta)
grad_theta_stopped = array_ops.stop_gradient(grad_theta) grad_theta_stopped = array_ops.stop_gradient(grad_theta)
@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(r[1].eval(), 65536.0) self.assertEqual(r[1].eval(), 65536.0)
self.assertEqual(grad.eval(), 524288.0) self.assertEqual(grad.eval(), 524288.0)
# while_v2 does not have stacks. # while_v2 does not have stacks.
if not control_flow_ops.ENABLE_WHILE_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual( self.assertEqual(
len([op for op in x.graph.get_operations() if op.type == "StackV2" len([op for op in x.graph.get_operations() if op.type == "StackV2"
]), 1) ]), 1)

View File

@ -23,6 +23,7 @@ from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -30,14 +31,11 @@ from tensorflow.python.platform import test
class ControlFlowUtilV2Test(test.TestCase): class ControlFlowUtilV2Test(test.TestCase):
def setUp(self): def setUp(self):
self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2 self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2 control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
control_flow_ops.ENABLE_COND_V2 = True
control_flow_ops.ENABLE_WHILE_V2 = True
def tearDown(self): def tearDown(self):
control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old
control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old
def _create_control_flow(self, expect_in_defun): def _create_control_flow(self, expect_in_defun):
"""Helper method for testInDefun.""" """Helper method for testInDefun."""

View File

@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase):
@test_util.run_deprecated_v1 @test_util.run_deprecated_v1
def testSkipEagerTensorArrayGradGrad(self): def testSkipEagerTensorArrayGradGrad(self):
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.skipTest("Legacy TensorArray does not support double derivatives.") self.skipTest("Legacy TensorArray does not support double derivatives.")
with self.test_session(use_gpu=True) as session: with self.test_session(use_gpu=True) as session:
x = constant_op.constant(4.0) x = constant_op.constant(4.0)
@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase):
with self.session(use_gpu=True): with self.session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32) ta = _make_ta(3, "foo", dtype=dtypes.float32)
# Test writing the wrong datatype # Test writing the wrong datatype
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()): not context.executing_eagerly()):
error_msg = ("Invalid data types; op elements string but list elements " error_msg = ("Invalid data types; op elements string but list elements "
"float") "float")
@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg): with self.assertRaisesOpError(error_msg):
self.evaluate(ta.write(0, "wrong_type_scalar").flow) self.evaluate(ta.write(0, "wrong_type_scalar").flow)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()): not context.executing_eagerly()):
error_msg = "Trying to modify element -1 in a list with 3 elements." error_msg = "Trying to modify element -1 in a list with 3 elements."
else: else:
@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg): with self.assertRaisesOpError(error_msg):
self.evaluate(ta.write(-1, 3.0).flow) self.evaluate(ta.write(-1, 3.0).flow)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()): not context.executing_eagerly()):
error_msg = "Trying to modify element 3 in a list with 3 elements" error_msg = "Trying to modify element 3 in a list with 3 elements"
else: else:
@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase):
# Test reading wrong datatype (only possible when constructing graphs). # Test reading wrong datatype (only possible when constructing graphs).
if (not context.executing_eagerly() and if (not context.executing_eagerly() and
not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2): not control_flow_util.ENABLE_CONTROL_FLOW_V2):
r0_bad = gen_data_flow_ops.tensor_array_read_v3( r0_bad = gen_data_flow_ops.tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError( with self.assertRaisesOpError(
"TensorArray dtype is float but Op requested dtype double."): "TensorArray dtype is float but Op requested dtype double."):
self.evaluate(r0_bad) self.evaluate(r0_bad)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()): not context.executing_eagerly()):
error_msg = "Trying to access element -1 in a list with 3 elements." error_msg = "Trying to access element -1 in a list with 3 elements."
else: else:
@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg): with self.assertRaisesOpError(error_msg):
self.evaluate(ta.read(-1)) self.evaluate(ta.read(-1))
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()): not context.executing_eagerly()):
error_msg = "Trying to access element 3 in a list with 3 elements." error_msg = "Trying to access element 3 in a list with 3 elements."
else: else:
@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase):
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1" error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not in_eager_mode else not in_eager_mode else
r"Expected sum of lengths to be equal to values.shape\[0\], " r"Expected sum of lengths to be equal to values.shape\[0\], "
r"but sum of lengths is 1 and value's shape is: \[3\]") r"but sum of lengths is 1 and value's shape is: \[3\]")
@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase):
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow) self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
ta = _make_ta(1, "baz") ta = _make_ta(1, "baz")
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode: if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
with self.assertRaisesRegexp( with self.assertRaisesRegexp(
ValueError, "Shape must be at least rank 1 but is rank 0"): ValueError, "Shape must be at least rank 1 but is rank 0"):
self.evaluate(ta.split(1.0, [1]).flow) self.evaluate(ta.split(1.0, [1]).flow)
@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase):
): ):
self.evaluate(ta.split(1.0, [1]).flow) self.evaluate(ta.split(1.0, [1]).flow)
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode: if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
ta = _make_ta(2, "buz") ta = _make_ta(2, "buz")
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"TensorArray's size is not equal to the size of lengths " r"TensorArray's size is not equal to the size of lengths "
@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase):
# self._testWhileLoopWritePackGradients( # self._testWhileLoopWritePackGradients(
# dynamic_size=False, dtype=tf.int64) # dynamic_size=False, dtype=tf.int64)
@test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA")
@test_util.enable_tensor_array_v2
def testWhileLoopV1WithTensorArrayV2(self):
size = 3
ta = tensor_array_ops.TensorArray(
dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar())
def Body(counter, ta):
return counter + 1, ta.write(counter, counter)
_, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta])
for i in range(size):
self.assertEqual(self.evaluate(ta.read(i)), i)
@test_util.disable_control_flow_v2("b/117943489 (dynamic_size)") @test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
@test_util.run_v1_only("b/117943489") @test_util.run_v1_only("b/117943489")
def testSkipEagerWhileLoopDynamicWritePackGradients(self): def testSkipEagerWhileLoopDynamicWritePackGradients(self):
@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase):
self.assertEqual((2, 2), w0.read(1).get_shape()) self.assertEqual((2, 2), w0.read(1).get_shape())
else: else:
self.assertEqual(r0.get_shape().ndims, None) self.assertEqual(r0.get_shape().ndims, None)
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2: if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual( self.assertEqual(
tensor_shape.TensorShape( tensor_shape.TensorShape(
ta1.handle.op.get_attr("element_shape")).ndims, None) ta1.handle.op.get_attr("element_shape")).ndims, None)
@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase):
"TensorArray has size zero, but element shape <unknown> is not " "TensorArray has size zero, but element shape <unknown> is not "
"fully defined. Currently only static shapes are supported when " "fully defined. Currently only static shapes are supported when "
"packing zero-size TensorArrays.") "packing zero-size TensorArrays.")
with self.assertRaisesOpError(v2_msg if tensor_array_ops with self.assertRaisesOpError(
.ENABLE_TENSOR_ARRAY_V2 else v1_msg): v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
ta.stack().eval() ta.stack().eval()
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")

View File

@ -24,13 +24,11 @@ from __future__ import print_function
import abc import abc
import collections import collections
import functools import functools
import os
import six import six
from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(),
while_v2 = LazyLoader("while_v2", globals(), while_v2 = LazyLoader("while_v2", globals(),
"tensorflow.python.ops.while_v2") "tensorflow.python.ops.while_v2")
ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0"
ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's # We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module. # existing 'tuple' for later use in this module.
_basetuple = tuple _basetuple = tuple
@ -2052,7 +2047,7 @@ def cond(pred,
``` ```
""" """
if ENABLE_COND_V2 and not context.executing_eagerly(): if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
return cond_v2.cond_v2(pred, true_fn, false_fn, name) return cond_v2.cond_v2(pred, true_fn, false_fn, name)
# We needed to make true_fn/false_fn keyword arguments for # We needed to make true_fn/false_fn keyword arguments for
@ -3487,7 +3482,7 @@ def while_loop(cond,
``` ```
""" """
if ENABLE_WHILE_V2 and not context.executing_eagerly(): if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
return while_v2.while_loop( return while_v2.while_loop(
cond, cond,
body, body,

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark):
iters=self.NUM_ITERS) iters=self.NUM_ITERS)
def benchmark_cond_v1_defun(self): def benchmark_cond_v1_defun(self):
old_val = control_flow_ops.ENABLE_COND_V2 old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_ops.ENABLE_COND_V2 = False control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
self._benchmark_defun() self._benchmark_defun()
control_flow_ops.ENABLE_COND_V2 = old_val control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v2_defun(self): def benchmark_cond_v2_defun(self):
old_val = control_flow_ops.ENABLE_COND_V2 old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_ops.ENABLE_COND_V2 = True control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
self._benchmark_defun() self._benchmark_defun()
control_flow_ops.ENABLE_COND_V2 = old_val control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v1_graph(self): def benchmark_cond_v1_graph(self):
old_val = control_flow_ops.ENABLE_COND_V2 old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_ops.ENABLE_COND_V2 = False control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
self._benchmark_graph() self._benchmark_graph()
control_flow_ops.ENABLE_COND_V2 = old_val control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v2_graph(self): def benchmark_cond_v2_graph(self):
old_val = control_flow_ops.ENABLE_COND_V2 old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_ops.ENABLE_COND_V2 = True control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
self._benchmark_graph() self._benchmark_graph()
control_flow_ops.ENABLE_COND_V2 = old_val control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
if __name__ == "__main__": if __name__ == "__main__":
ops.enable_eager_execution() ops.enable_eager_execution()

View File

@ -23,10 +23,18 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import traceback import traceback
from tensorflow.python import tf2
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
ENABLE_CONTROL_FLOW_V2 = (tf2.enabled() or
os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or
os.getenv("TF_ENABLE_COND_V2", "0") != "0" or
os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or
os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0")
def IsInXLAContext(op): def IsInXLAContext(op):
try: try:

View File

@ -20,10 +20,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import contextlib import contextlib
import os
import weakref import weakref
from tensorflow.python import tf2
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -32,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops from tensorflow.python.ops import list_ops
@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
ENABLE_TENSOR_ARRAY_V2 = (
tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None)
# _GraphTensorArray accesses many of the hidden generated ops, but is in # _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods. # fact built to wrap these methods.
# pylint: disable=protected-access # pylint: disable=protected-access
@ -1013,7 +1008,7 @@ class TensorArray(object):
if context.executing_eagerly(): if context.executing_eagerly():
implementation = _EagerTensorArray implementation = _EagerTensorArray
else: else:
if ENABLE_TENSOR_ARRAY_V2: if control_flow_util.ENABLE_CONTROL_FLOW_V2:
implementation = _GraphTensorArrayV2 implementation = _GraphTensorArrayV2
else: else:
implementation = _GraphTensorArray implementation = _GraphTensorArray

View File

@ -52,13 +52,6 @@ from tensorflow.python.util import nest
# to them and then pass those in as data inputs. This should probably be # to them and then pass those in as data inputs. This should probably be
# handled in the CapturingGraph itself. # handled in the CapturingGraph itself.
# Op types that output a resource tensor representing a TensorArray handle.
TENSOR_ARRAY_HANDLE_OPS = (
"TensorArrayV3",
"TensorArrayGradV3",
"TensorArrayGradWithShape",
)
def while_loop(cond, def while_loop(cond,
body, body,
@ -257,24 +250,19 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
"_maximum_iterations") if _is_in_xla_context() else None "_maximum_iterations") if _is_in_xla_context() else None
assert not _is_in_xla_context() or maximum_iterations is not None assert not _is_in_xla_context() or maximum_iterations is not None
# Set the incoming gradient of TensorArray handles to None. The gradient # Set the incoming gradient of non-trainable inputs to None. It is possible
# implementation currently assumes all resource tensors correspond to float32 # that we receive non-None gradients for non-trainable types in nested while
# ResourceVariables, which can lead to runtime shape errors when used with a # loops because we accumulate outputs of the inner while as variant tensors
# TensorArray. This is a workaround until TensorArrays are reimplemented with # which are trainable and hence receive zeros_like tensors in the gradient
# TensorLists instead of resources. # pass. The non-trainable tensors then receive the popped zeros tensor from
# Also set the incoming gradient of non-trainable inputs to None. It is # this zeros variant. The gradient for the loop vars corresponding to these
# possible that we receive non-None gradients for non-trainable types in # tensors is None or zeros (this happens only if the loop var is accumulated
# nested while loops because we accumulate outputs of the inner while as # as well) in _grad_fn so we reset these.
# variant tensors which are trainable and hence receive zeros_like tensors in
# the gradient pass. The non-trainable tensors then receive the popped zeros
# tensor from this zeros variant. The gradient for the loop vars corresponding
# to these tensors is None or zeros (this happens only if the loop var is
# accumulated as well) in _grad_fn so we reset these.
# TODO(b/118712257): Remove the IsTrainable filter once we can handle None # TODO(b/118712257): Remove the IsTrainable filter once we can handle None
# output grads in _grad_fn. # output grads in _grad_fn.
grads = [ grads = [
None if _is_tensor_array_handle(output) or not _is_trainable(output) None if not _is_trainable(output) else grad
else grad for grad, output in zip(grads, body_graph.outputs) for grad, output in zip(grads, body_graph.outputs)
] ]
# Ensure that all non-resource trainable outputs have incoming gradients. # Ensure that all non-resource trainable outputs have incoming gradients.
@ -339,8 +327,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
# See comment in while_loop. # See comment in while_loop.
outputs = [array_ops.identity(t) for t in outputs] outputs = [array_ops.identity(t) for t in outputs]
# Set None as the output gradient for tensors with None input gradient # Set None as the output gradient for tensors with None input gradient.
# e.g. TensorArray handles.
# outputs[0] is the loop counter. # outputs[0] is the loop counter.
# outputs[1] is the total number of loop iterations. # outputs[1] is the total number of loop iterations.
index = 2 index = 2
@ -853,28 +840,6 @@ def _graph_name(graph):
return "Base" return "Base"
def _is_tensor_array_handle(tensor):
"""Returns whether tensor is a TensorArray handle."""
if tensor.dtype != dtypes.resource:
return False
if tensor.op.type == "While":
# We assume that any resource outputs of a While op correspond to a captured
# resource input (as opposed to a loop variable specified by the user).
# NOTE(skyewm): we could actually check this, but I can't think of when you
# would have a resource loop variable.
tensor = tensor.op.inputs[tensor.value_index]
# TODO(b/118452219): add test coverage for this.
tensor = func_graph_module.maybe_captured(tensor)
if isinstance(tensor, ops.EagerTensor):
# Eager execution doesn't quite support legacy tensorarray
return False
return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
def _pack_sequence_as(structure_with_tas, loop_vars): def _pack_sequence_as(structure_with_tas, loop_vars):
"""Like `nest.pack_sequence_as` but also replaces flows with TensorArrays.""" """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""