Explain when callables passed to tf.cond & tf.while_loop are run.

Rename the parameters to tf.cond.
Change: 154774725
A. Unique TensorFlower 2017-05-01 15:11:37 -08:00 committed by TensorFlower Gardener
parent 8a123f7d1b
commit 7c561e09c0
3 changed files with 137 additions and 39 deletions


@@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.gen_control_flow_ops import *
 # pylint: enable=wildcard-import
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_should_use
@@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res):
   return res


-def cond(pred, fn1, fn2, strict=False, name=None):
-  """Return `fn1()` if the boolean predicate `pred` is true else `fn2()`.
+# pylint: disable=g-doc-args
+@deprecation.deprecated_args(
+    None,
+    "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
+    "fn1", "fn2")
+def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
+         fn1=None, fn2=None):
+  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

-  `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
-  the same non-zero number and type of outputs.
+  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
+  `false_fn` must have the same non-zero number and type of outputs.

   Note that the conditional execution applies only to the operations defined in
-  `fn1` and `fn2`. Consider the following simple program:
+  `true_fn` and `false_fn`. Consider the following simple program:

   ```python
   z = tf.multiply(a, b)
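
A minimal usage sketch of the renamed signature (the constants are illustrative and mirror the new tests added below); the deprecated `fn1`/`fn2` keywords still work but route through the deprecation decorator above:

```python
import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

# Preferred spelling after this change.
z_new = tf.cond(tf.less(x, y),
                true_fn=lambda: tf.multiply(x, 17),
                false_fn=lambda: tf.add(y, 23))

# Deprecated spelling; still accepted, but logs a deprecation warning.
z_old = tf.cond(tf.less(x, y),
                fn1=lambda: tf.multiply(x, 17),
                fn2=lambda: tf.add(y, 23))

with tf.Session() as sess:
  print(sess.run([z_new, z_old]))  # [34, 34]
```
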
@@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
   Although this behavior is consistent with the dataflow model of TensorFlow,
   it has occasionally surprised some users who expected a lazier semantics.

+  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
+  call to `cond`, and not at all during `Session.run()`). `cond`
+  stitches together the graph fragments created during the `true_fn` and
+  `false_fn` calls with some additional graph nodes to ensure that the right
+  branch gets executed depending on the value of `pred`.
+
   `tf.cond` supports nested structures as implemented in
-  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
-  (possibly nested) value structure of lists, tuples, and/or named tuples.
+  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
+  same (possibly nested) value structure of lists, tuples, and/or named tuples.
   Singleton lists and tuples form the only exceptions to this: when returned by
-  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
-  behavior is disabled by passing `strict=True`.
+  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
+  This behavior is disabled by passing `strict=True`.

   Args:
-    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
-    fn1: The callable to be performed if pred is true.
-    fn2: The callable to be performed if pred is false.
+    pred: A scalar determining whether to return the result of `true_fn` or
+      `false_fn`.
+    true_fn: The callable to be performed if pred is true.
+    false_fn: The callable to be performed if pred is false.
     strict: A boolean that enables/disables 'strict' mode; see above.
     name: Optional name prefix for the returned tensors.

   Returns:
-    Tensors returned by the call to either `fn1` or `fn2`. If the callables
-    return a singleton list, the element is extracted from the list.
+    Tensors returned by the call to either `true_fn` or `false_fn`. If the
+    callables return a singleton list, the element is extracted from the list.

   Raises:
-    TypeError: if `fn1` or `fn2` is not callable.
-    ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
-                return tensors of different types.
+    TypeError: if `true_fn` or `false_fn` is not callable.
+    ValueError: if `true_fn` and `false_fn` do not return the same number of
+      tensors, or return tensors of different types.

   Example:
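
To make the new "exactly once" paragraph concrete, a small sketch (names and values are made up for illustration): the Python-side prints fire while `cond` builds the graph, not during `Session.run()`:

```python
import tensorflow as tf

pred = tf.placeholder(tf.bool)

def true_fn():
  print("tracing true_fn")   # runs once, while the graph is being built
  return tf.constant(1)

def false_fn():
  print("tracing false_fn")  # also runs once, at graph-construction time
  return tf.constant(0)

result = tf.cond(pred, true_fn, false_fn)  # both messages print here

with tf.Session() as sess:
  # No further Python prints; only the selected branch's ops execute.
  print(sess.run(result, feed_dict={pred: True}))   # 1
  print(sess.run(result, feed_dict={pred: False}))  # 0
```
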
@@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
   ```

   """
-  with ops.name_scope(name, "cond", [pred]) as name:
-    if not callable(fn1):
-      raise TypeError("fn1 must be callable.")
-    if not callable(fn2):
-      raise TypeError("fn2 must be callable.")
+  # We needed to make true_fn/false_fn keyword arguments for
+  # backwards-compatibility. This check exists so that we can convert back to
+  # having them be positional arguments.
+  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
+  # `fn1` and `fn2` are deleted.
+  if fn1 is not None:
+    if true_fn is not None:
+      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
+    true_fn = fn1
+  elif true_fn is None:
+    raise TypeError("cond(): true_fn argument required")
+  if fn2 is not None:
+    if false_fn is not None:
+      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
+    false_fn = fn2
+  elif false_fn is None:
+    raise TypeError("cond(): false_fn argument required")
+
+  if not callable(true_fn):
+    raise TypeError("true_fn must be callable.")
+  if not callable(false_fn):
+    raise TypeError("false_fn must be callable.")
+
+  with ops.name_scope(name, "cond", [pred]) as name:
     # Add the Switch to the graph.
     if isinstance(pred, bool):
       raise TypeError("pred must not be a Python bool")
@@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
     # Build the graph for the true branch in a new context.
     context_t = CondContext(pred, pivot_1, branch=1)
     context_t.Enter()
-    orig_res_t, res_t = context_t.BuildCondBranch(fn1)
+    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
     if orig_res_t is None:
-      raise ValueError("fn1 must have a return value.")
+      raise ValueError("true_fn must have a return value.")
     context_t.ExitResult(res_t)
     context_t.Exit()

     # Build the graph for the false branch in a new context.
     context_f = CondContext(pred, pivot_2, branch=0)
     context_f.Enter()
-    orig_res_f, res_f = context_f.BuildCondBranch(fn2)
+    orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
     if orig_res_f is None:
-      raise ValueError("fn2 must have a return value.")
+      raise ValueError("false_fn must have a return value.")
     context_f.ExitResult(res_f)
     context_f.Exit()
@@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
       nest.assert_same_structure(orig_res_t, orig_res_f)
     except TypeError as e:
       raise TypeError(
-          "Incompatible return types of fn1 and fn2: {}".format(e))
+          "Incompatible return types of true_fn and false_fn: {}".format(e))
     except ValueError as e:
       raise ValueError(
-          "Incompatible return values of fn1 and fn2: {}".format(e))
+          "Incompatible return values of true_fn and false_fn: {}".format(e))

     # Add the final merge to the graph.
     if not res_t:
-      raise ValueError("fn1 and fn2 must return at least one result.")
+      raise ValueError("true_fn and false_fn must return at least one result.")

     res_t_flat = nest.flatten(res_t)
     res_f_flat = nest.flatten(res_f)
@@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
       val_x = x if isinstance(x, ops.Tensor) else x.values
       val_y = y if isinstance(y, ops.Tensor) else y.values
       if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
-        raise ValueError("Outputs of fn1 and fn2 must have the same type: "
-                         "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
+        raise ValueError(
+            "Outputs of true_fn and false_fn must have the same type: %s, %s" %
+            (val_x.dtype.name, val_y.dtype.name))

     merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
     merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
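
For illustration, a hedged sketch of the error path touched above: branches that return tensors of different dtypes hit this `ValueError` (the constants are arbitrary):

```python
import tensorflow as tf

pred = tf.placeholder(tf.bool)

try:
  tf.cond(pred,
          lambda: tf.constant(1),    # int32
          lambda: tf.constant(1.0))  # float32
except ValueError as e:
  # "Outputs of true_fn and false_fn must have the same type: int32, float32"
  print(e)
```
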
@@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
     if not strict:
       merges = _UnpackIfSingleton(merges)
     return merges
+# pylint: enable=g-doc-args


 def _resource_safe_shape(t):
@@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
   `cond` and `body`. `cond` and `body` both take as many arguments as there are
   `loop_vars`.

-  While `cond` evaluates to true, `body` is executed.
-
   In addition to regular Tensors or IndexedSlices, the body may accept and
   return TensorArray objects. The flows of the TensorArray objects will
   be appropriately forwarded between loops and during gradient calculations.

+  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
+  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
+  stitches together the graph fragments created during the `cond` and `body`
+  calls with some additional graph nodes to make something that repeats
+  `body` until `cond` returns false.
+
   For correctness, `tf.while_loop()` strictly enforces shape invariants for
   the loop variables. A shape invariant is a (possibly partial) shape that
   is unchanged across the iterations of the loop. An error will be raised
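
The same point for `tf.while_loop`, as a small sketch (the loop bound is arbitrary): `cond` and `body` are each called once to build the loop's graph; the iterations themselves happen when the result is run:

```python
import tensorflow as tf

def loop_cond(i):
  print("tracing cond")  # executes once, while the graph is built
  return tf.less(i, 10)

def loop_body(i):
  print("tracing body")  # also executes once
  return tf.add(i, 1)

i_final = tf.while_loop(loop_cond, loop_body, [tf.constant(0)])

with tf.Session() as sess:
  print(sess.run(i_final))  # 10 -- the ten iterations run inside Session.run()
```
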
@@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
   operation returns the tensors generated by `default`.

   `tf.case` supports nested structures as implemented in
-  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
+  `tensorflow.python.util.nest`. All of the callables must return the same
   (possibly nested) value structure of lists, tuples, and/or named tuples.
   Singleton lists and tuples form the only exceptions to this: when returned by
-  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
+  a callable, they are implicitly unpacked to single values. This
   behavior is disabled by passing `strict=True`.

   Example 1:
@@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
   Expressions:
   ```
-    x = tf.constant(0)
-    y = tf.constant(1)
-    z = tf.constant(2)
     def f1(): return tf.constant(17)
     def f2(): return tf.constant(23)
     def f3(): return tf.constant(-1)
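
For reference, the docstring's trimmed example as a runnable snippet; the `x`, `y`, `z` constants dropped from the text above are reintroduced here purely for illustration:

```python
import tensorflow as tf

x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)

r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
            default=f3, exclusive=True)

with tf.Session() as sess:
  print(sess.run(r))  # 17, since x < y and not x > z
```
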


@@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
       self.assertEquals(grad_x_false.eval(), 0.)


+class CondTest(TensorFlowTestCase):
+
+  def testCondTrue(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(5)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+          lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 34)
+
+  def testCondFalse(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(1)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+          lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 24)
+
+  def testCondTrueLegacy(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(5)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
+          fn2=lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 34)
+
+  def testCondFalseLegacy(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(1)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
+          fn2=lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 24)
+
+  def testCondMissingArg1(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, false_fn=lambda: x)
+
+  def testCondMissingArg2(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x)
+
+  def testCondDuplicateArg1(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
+
+  def testCondDuplicateArg2(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
+
+
 class ContextTest(TensorFlowTestCase):

   def testCondContext(self):


@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "cond"
-    argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+    argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "confusion_matrix"