Explain when callables passed to tf.cond & tf.while_loop are run.
Rename the parameters to tf.cond. Change: 154774725
This commit is contained in:
parent
8a123f7d1b
commit
7c561e09c0
@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.gen_control_flow_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_should_use
|
||||
|
||||
@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res):
|
||||
return res
|
||||
|
||||
|
||||
def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
"""Return `fn1()` if the boolean predicate `pred` is true else `fn2()`.
|
||||
# pylint: disable=g-doc-args
|
||||
@deprecation.deprecated_args(
|
||||
None,
|
||||
"fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
|
||||
"fn1", "fn2")
|
||||
def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
|
||||
fn1=None, fn2=None):
|
||||
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
|
||||
|
||||
`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
|
||||
the same non-zero number and type of outputs.
|
||||
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
|
||||
`false_fn` must have the same non-zero number and type of outputs.
|
||||
|
||||
Note that the conditional execution applies only to the operations defined in
|
||||
`fn1` and `fn2`. Consider the following simple program:
|
||||
`true_fn` and `false_fn`. Consider the following simple program:
|
||||
|
||||
```python
|
||||
z = tf.multiply(a, b)
|
||||
@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
Although this behavior is consistent with the dataflow model of TensorFlow,
|
||||
it has occasionally surprised some users who expected a lazier semantics.
|
||||
|
||||
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
|
||||
call to `cond`, and not at all during `Session.run()`). `cond`
|
||||
stitches together the graph fragments created during the `true_fn` and
|
||||
`false_fn` calls with some additional graph nodes to ensure that the right
|
||||
branch gets executed depending on the value of `pred`.
|
||||
|
||||
`tf.cond` supports nested structures as implemented in
|
||||
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
|
||||
(possibly nested) value structure of lists, tuples, and/or named tuples.
|
||||
`tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
|
||||
same (possibly nested) value structure of lists, tuples, and/or named tuples.
|
||||
Singleton lists and tuples form the only exceptions to this: when returned by
|
||||
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
|
||||
behavior is disabled by passing `strict=True`.
|
||||
`true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
|
||||
This behavior is disabled by passing `strict=True`.
|
||||
|
||||
Args:
|
||||
pred: A scalar determining whether to return the result of `fn1` or `fn2`.
|
||||
fn1: The callable to be performed if pred is true.
|
||||
fn2: The callable to be performed if pred is false.
|
||||
pred: A scalar determining whether to return the result of `true_fn` or
|
||||
`false_fn`.
|
||||
true_fn: The callable to be performed if pred is true.
|
||||
false_fn: The callable to be performed if pred is false.
|
||||
strict: A boolean that enables/disables 'strict' mode; see above.
|
||||
name: Optional name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
Tensors returned by the call to either `fn1` or `fn2`. If the callables
|
||||
return a singleton list, the element is extracted from the list.
|
||||
Tensors returned by the call to either `true_fn` or `false_fn`. If the
|
||||
callables return a singleton list, the element is extracted from the list.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn1` or `fn2` is not callable.
|
||||
ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
|
||||
return tensors of different types.
|
||||
TypeError: if `true_fn` or `false_fn` is not callable.
|
||||
ValueError: if `true_fn` and `false_fn` do not return the same number of
|
||||
tensors, or return tensors of different types.
|
||||
|
||||
Example:
|
||||
|
||||
@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
```
|
||||
|
||||
"""
|
||||
with ops.name_scope(name, "cond", [pred]) as name:
|
||||
if not callable(fn1):
|
||||
raise TypeError("fn1 must be callable.")
|
||||
if not callable(fn2):
|
||||
raise TypeError("fn2 must be callable.")
|
||||
# We needed to make true_fn/false_fn keyword arguments for
|
||||
# backwards-compatibility. This check exists so that we can convert back to
|
||||
# having them be positional arguments.
|
||||
# TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
|
||||
# `fn1` and `fn2` are deleted.
|
||||
if fn1 is not None:
|
||||
if true_fn is not None:
|
||||
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
|
||||
true_fn = fn1
|
||||
elif true_fn is None:
|
||||
raise TypeError("cond(): true_fn argument required")
|
||||
if fn2 is not None:
|
||||
if false_fn is not None:
|
||||
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
|
||||
false_fn = fn2
|
||||
elif false_fn is None:
|
||||
raise TypeError("cond(): false_fn argument required")
|
||||
|
||||
if not callable(true_fn):
|
||||
raise TypeError("true_fn must be callable.")
|
||||
if not callable(false_fn):
|
||||
raise TypeError("false_fn must be callable.")
|
||||
|
||||
with ops.name_scope(name, "cond", [pred]) as name:
|
||||
# Add the Switch to the graph.
|
||||
if isinstance(pred, bool):
|
||||
raise TypeError("pred must not be a Python bool")
|
||||
@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
# Build the graph for the true branch in a new context.
|
||||
context_t = CondContext(pred, pivot_1, branch=1)
|
||||
context_t.Enter()
|
||||
orig_res_t, res_t = context_t.BuildCondBranch(fn1)
|
||||
orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
|
||||
if orig_res_t is None:
|
||||
raise ValueError("fn1 must have a return value.")
|
||||
raise ValueError("true_fn must have a return value.")
|
||||
context_t.ExitResult(res_t)
|
||||
context_t.Exit()
|
||||
|
||||
# Build the graph for the false branch in a new context.
|
||||
context_f = CondContext(pred, pivot_2, branch=0)
|
||||
context_f.Enter()
|
||||
orig_res_f, res_f = context_f.BuildCondBranch(fn2)
|
||||
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
|
||||
if orig_res_f is None:
|
||||
raise ValueError("fn2 must have a return value.")
|
||||
raise ValueError("false_fn must have a return value.")
|
||||
context_f.ExitResult(res_f)
|
||||
context_f.Exit()
|
||||
|
||||
@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
nest.assert_same_structure(orig_res_t, orig_res_f)
|
||||
except TypeError as e:
|
||||
raise TypeError(
|
||||
"Incompatible return types of fn1 and fn2: {}".format(e))
|
||||
"Incompatible return types of true_fn and false_fn: {}".format(e))
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Incompatible return values of fn1 and fn2: {}".format(e))
|
||||
"Incompatible return values of true_fn and false_fn: {}".format(e))
|
||||
|
||||
# Add the final merge to the graph.
|
||||
if not res_t:
|
||||
raise ValueError("fn1 and fn2 must return at least one result.")
|
||||
raise ValueError("true_fn and false_fn must return at least one result.")
|
||||
|
||||
res_t_flat = nest.flatten(res_t)
|
||||
res_f_flat = nest.flatten(res_f)
|
||||
@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
val_x = x if isinstance(x, ops.Tensor) else x.values
|
||||
val_y = y if isinstance(y, ops.Tensor) else y.values
|
||||
if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
|
||||
raise ValueError("Outputs of fn1 and fn2 must have the same type: "
|
||||
"%s, %s" % (val_x.dtype.name, val_y.dtype.name))
|
||||
raise ValueError(
|
||||
"Outputs of true_fn and false_fn must have the same type: %s, %s" %
|
||||
(val_x.dtype.name, val_y.dtype.name))
|
||||
|
||||
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
|
||||
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
|
||||
@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
||||
if not strict:
|
||||
merges = _UnpackIfSingleton(merges)
|
||||
return merges
|
||||
# pylint: enable=g-doc-args
|
||||
|
||||
|
||||
def _resource_safe_shape(t):
|
||||
@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
|
||||
`cond` and `body`. `cond` and `body` both take as many arguments as there are
|
||||
`loop_vars`.
|
||||
|
||||
While `cond` evaluates to true, `body` is executed.
|
||||
|
||||
In addition to regular Tensors or IndexedSlices, the body may accept and
|
||||
return TensorArray objects. The flows of the TensorArray objects will
|
||||
be appropriately forwarded between loops and during gradient calculations.
|
||||
|
||||
Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
|
||||
call to `while_loop`, and not at all during `Session.run()`). `while_loop`
|
||||
stitches together the graph fragments created during the `cond` and `body`
|
||||
calls with some additional graph nodes to make something the repeats
|
||||
`body` until `cond` returns false.
|
||||
|
||||
For correctness, `tf.while_loop()` strictly enforces shape invariants for
|
||||
the loop variables. A shape invariant is a (possibly partial) shape that
|
||||
is unchanged across the iterations of the loop. An error will be raised
|
||||
@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
|
||||
operation returns the tensors generated by `default`.
|
||||
|
||||
`tf.case` supports nested structures as implemented in
|
||||
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
|
||||
`tensorflow.python.util.nest`. All of the callables must return the same
|
||||
(possibly nested) value structure of lists, tuples, and/or named tuples.
|
||||
Singleton lists and tuples form the only exceptions to this: when returned by
|
||||
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
|
||||
a callable, they are implicitly unpacked to single values. This
|
||||
behavior is disabled by passing `strict=True`.
|
||||
|
||||
Example 1:
|
||||
@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
|
||||
|
||||
Expressions:
|
||||
```
|
||||
x = tf.constant(0)
|
||||
y = tf.constant(1)
|
||||
z = tf.constant(2)
|
||||
def f1(): return tf.constant(17)
|
||||
def f2(): return tf.constant(23)
|
||||
def f3(): return tf.constant(-1)
|
||||
|
@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
|
||||
self.assertEquals(grad_x_false.eval(), 0.)
|
||||
|
||||
|
||||
class CondTest(TensorFlowTestCase):
|
||||
|
||||
def testCondTrue(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(2)
|
||||
y = constant_op.constant(5)
|
||||
z = control_flow_ops.cond(
|
||||
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
|
||||
lambda: math_ops.add(y, 23))
|
||||
self.assertEquals(z.eval(), 34)
|
||||
|
||||
def testCondFalse(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(2)
|
||||
y = constant_op.constant(1)
|
||||
z = control_flow_ops.cond(
|
||||
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
|
||||
lambda: math_ops.add(y, 23))
|
||||
self.assertEquals(z.eval(), 24)
|
||||
|
||||
def testCondTrueLegacy(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(2)
|
||||
y = constant_op.constant(5)
|
||||
z = control_flow_ops.cond(
|
||||
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
|
||||
fn2=lambda: math_ops.add(y, 23))
|
||||
self.assertEquals(z.eval(), 34)
|
||||
|
||||
def testCondFalseLegacy(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(2)
|
||||
y = constant_op.constant(1)
|
||||
z = control_flow_ops.cond(
|
||||
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
|
||||
fn2=lambda: math_ops.add(y, 23))
|
||||
self.assertEquals(z.eval(), 24)
|
||||
|
||||
def testCondMissingArg1(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaises(TypeError):
|
||||
control_flow_ops.cond(True, false_fn=lambda: x)
|
||||
|
||||
def testCondMissingArg2(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaises(TypeError):
|
||||
control_flow_ops.cond(True, lambda: x)
|
||||
|
||||
def testCondDuplicateArg1(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaises(TypeError):
|
||||
control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
|
||||
|
||||
def testCondDuplicateArg2(self):
|
||||
with self.test_session():
|
||||
x = constant_op.constant(1)
|
||||
with self.assertRaises(TypeError):
|
||||
control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
|
||||
|
||||
|
||||
class ContextTest(TensorFlowTestCase):
|
||||
|
||||
def testCondContext(self):
|
||||
|
@ -718,7 +718,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "cond"
|
||||
argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
||||
argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "confusion_matrix"
|
||||
|
Loading…
Reference in New Issue
Block a user