Explain when callables passed to tf.cond & tf.while_loop are run.

Rename the parameters to tf.cond.
Change: 154774725
This commit is contained in:
A. Unique TensorFlower 2017-05-01 15:11:37 -08:00 committed by TensorFlower Gardener
parent 8a123f7d1b
commit 7c561e09c0
3 changed files with 137 additions and 39 deletions

View File

@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res):
return res
def cond(pred, fn1, fn2, strict=False, name=None):
"""Return `fn1()` if the boolean predicate `pred` is true else `fn2()`.
# pylint: disable=g-doc-args
@deprecation.deprecated_args(
None,
"fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2")
def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
fn1=None, fn2=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
the same non-zero number and type of outputs.
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
`false_fn` must have the same non-zero number and type of outputs.
Note that the conditional execution applies only to the operations defined in
`fn1` and `fn2`. Consider the following simple program:
`true_fn` and `false_fn`. Consider the following simple program:
```python
z = tf.multiply(a, b)
@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
Although this behavior is consistent with the dataflow model of TensorFlow,
it has occasionally surprised some users who expected a lazier semantics.
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
call to `cond`, and not at all during `Session.run()`). `cond`
stitches together the graph fragments created during the `true_fn` and
`false_fn` calls with some additional graph nodes to ensure that the right
branch gets executed depending on the value of `pred`.
`tf.cond` supports nested structures as implemented in
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
(possibly nested) value structure of lists, tuples, and/or named tuples.
`tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
same (possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
behavior is disabled by passing `strict=True`.
`true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
This behavior is disabled by passing `strict=True`.
Args:
pred: A scalar determining whether to return the result of `fn1` or `fn2`.
fn1: The callable to be performed if pred is true.
fn2: The callable to be performed if pred is false.
pred: A scalar determining whether to return the result of `true_fn` or
`false_fn`.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
strict: A boolean that enables/disables 'strict' mode; see above.
name: Optional name prefix for the returned tensors.
Returns:
Tensors returned by the call to either `fn1` or `fn2`. If the callables
return a singleton list, the element is extracted from the list.
Tensors returned by the call to either `true_fn` or `false_fn`. If the
callables return a singleton list, the element is extracted from the list.
Raises:
TypeError: if `fn1` or `fn2` is not callable.
ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
return tensors of different types.
TypeError: if `true_fn` or `false_fn` is not callable.
ValueError: if `true_fn` and `false_fn` do not return the same number of
tensors, or return tensors of different types.
Example:
@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
```
"""
with ops.name_scope(name, "cond", [pred]) as name:
if not callable(fn1):
raise TypeError("fn1 must be callable.")
if not callable(fn2):
raise TypeError("fn2 must be callable.")
# We needed to make true_fn/false_fn keyword arguments for
# backwards-compatibility. This check exists so that we can convert back to
# having them be positional arguments.
# TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
# `fn1` and `fn2` are deleted.
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")
if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")
with ops.name_scope(name, "cond", [pred]) as name:
# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
# Build the graph for the true branch in a new context.
context_t = CondContext(pred, pivot_1, branch=1)
context_t.Enter()
orig_res_t, res_t = context_t.BuildCondBranch(fn1)
orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
if orig_res_t is None:
raise ValueError("fn1 must have a return value.")
raise ValueError("true_fn must have a return value.")
context_t.ExitResult(res_t)
context_t.Exit()
# Build the graph for the false branch in a new context.
context_f = CondContext(pred, pivot_2, branch=0)
context_f.Enter()
orig_res_f, res_f = context_f.BuildCondBranch(fn2)
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
if orig_res_f is None:
raise ValueError("fn2 must have a return value.")
raise ValueError("false_fn must have a return value.")
context_f.ExitResult(res_f)
context_f.Exit()
@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
nest.assert_same_structure(orig_res_t, orig_res_f)
except TypeError as e:
raise TypeError(
"Incompatible return types of fn1 and fn2: {}".format(e))
"Incompatible return types of true_fn and false_fn: {}".format(e))
except ValueError as e:
raise ValueError(
"Incompatible return values of fn1 and fn2: {}".format(e))
"Incompatible return values of true_fn and false_fn: {}".format(e))
# Add the final merge to the graph.
if not res_t:
raise ValueError("fn1 and fn2 must return at least one result.")
raise ValueError("true_fn and false_fn must return at least one result.")
res_t_flat = nest.flatten(res_t)
res_f_flat = nest.flatten(res_f)
@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
val_x = x if isinstance(x, ops.Tensor) else x.values
val_y = y if isinstance(y, ops.Tensor) else y.values
if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
raise ValueError("Outputs of fn1 and fn2 must have the same type: "
"%s, %s" % (val_x.dtype.name, val_y.dtype.name))
raise ValueError(
"Outputs of true_fn and false_fn must have the same type: %s, %s" %
(val_x.dtype.name, val_y.dtype.name))
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
if not strict:
merges = _UnpackIfSingleton(merges)
return merges
# pylint: enable=g-doc-args
def _resource_safe_shape(t):
@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
`cond` and `body`. `cond` and `body` both take as many arguments as there are
`loop_vars`.
While `cond` evaluates to true, `body` is executed.
In addition to regular Tensors or IndexedSlices, the body may accept and
return TensorArray objects. The flows of the TensorArray objects will
be appropriately forwarded between loops and during gradient calculations.
Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
call to `while_loop`, and not at all during `Session.run()`). `while_loop`
stitches together the graph fragments created during the `cond` and `body`
calls with some additional graph nodes to make something that repeats
`body` until `cond` returns false.
For correctness, `tf.while_loop()` strictly enforces shape invariants for
the loop variables. A shape invariant is a (possibly partial) shape that
is unchanged across the iterations of the loop. An error will be raised
@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
operation returns the tensors generated by `default`.
`tf.case` supports nested structures as implemented in
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
`tensorflow.python.util.nest`. All of the callables must return the same
(possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
a callable, they are implicitly unpacked to single values. This
behavior is disabled by passing `strict=True`.
Example 1:
@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
Expressions:
```
x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)

View File

@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
self.assertEquals(grad_x_false.eval(), 0.)
class CondTest(TensorFlowTestCase):
  """Tests for `control_flow_ops.cond` branch selection and argument checks.

  Covers the `true_fn`/`false_fn` arguments, the legacy `fn1`/`fn2` keyword
  arguments, and the TypeErrors raised when a branch callable is missing or
  supplied under both its new and legacy names.
  """

  def testCondTrue(self):
    """Returns the `true_fn` result when the predicate is true."""
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      z = control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 34)

  def testCondFalse(self):
    """Returns the `false_fn` result when the predicate is false."""
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(1)
      z = control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 24)

  def testCondTrueLegacy(self):
    """Legacy `fn1`/`fn2` keywords still select the true branch."""
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      z = control_flow_ops.cond(
          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
          fn2=lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 34)

  def testCondFalseLegacy(self):
    """Legacy `fn1`/`fn2` keywords still select the false branch."""
    with self.test_session():
      x = constant_op.constant(2)
      y = constant_op.constant(1)
      z = control_flow_ops.cond(
          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
          fn2=lambda: math_ops.add(y, 23))
      self.assertEqual(z.eval(), 24)

  # The error-path tests below use a tensor predicate on purpose: `cond`
  # raises TypeError for a Python bool `pred`, which would make these tests
  # pass even if the branch-argument validation were broken.

  def testCondMissingArg1(self):
    """Raises TypeError when neither `true_fn` nor `fn1` is given."""
    with self.test_session():
      x = constant_op.constant(1)
      pred = constant_op.constant(True)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(pred, false_fn=lambda: x)

  def testCondMissingArg2(self):
    """Raises TypeError when neither `false_fn` nor `fn2` is given."""
    with self.test_session():
      x = constant_op.constant(1)
      pred = constant_op.constant(True)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(pred, lambda: x)

  def testCondDuplicateArg1(self):
    """Raises TypeError when both `true_fn` and `fn1` are given."""
    with self.test_session():
      x = constant_op.constant(1)
      pred = constant_op.constant(True)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(pred, lambda: x, lambda: x, fn1=lambda: x)

  def testCondDuplicateArg2(self):
    """Raises TypeError when both `false_fn` and `fn2` are given."""
    with self.test_session():
      x = constant_op.constant(1)
      pred = constant_op.constant(True)
      with self.assertRaises(TypeError):
        control_flow_ops.cond(pred, lambda: x, lambda: x, fn2=lambda: x)
class ContextTest(TensorFlowTestCase):
def testCondContext(self):

View File

@ -718,7 +718,7 @@ tf_module {
}
member_method {
name: "cond"
argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "confusion_matrix"