Explain when callables passed to tf.cond & tf.while_loop are run.
Rename the parameters to tf.cond. Change: 154774725
This commit is contained in:
parent
8a123f7d1b
commit
7c561e09c0
@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops
|
|||||||
from tensorflow.python.ops.gen_control_flow_ops import *
|
from tensorflow.python.ops.gen_control_flow_ops import *
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_should_use
|
from tensorflow.python.util import tf_should_use
|
||||||
|
|
||||||
@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def cond(pred, fn1, fn2, strict=False, name=None):
|
# pylint: disable=g-doc-args
|
||||||
"""Return `fn1()` if the boolean predicate `pred` is true else `fn2()`.
|
@deprecation.deprecated_args(
|
||||||
|
None,
|
||||||
|
"fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
|
||||||
|
"fn1", "fn2")
|
||||||
|
def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
|
||||||
|
fn1=None, fn2=None):
|
||||||
|
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
|
||||||
|
|
||||||
`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
|
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
|
||||||
the same non-zero number and type of outputs.
|
`false_fn` must have the same non-zero number and type of outputs.
|
||||||
|
|
||||||
Note that the conditional execution applies only to the operations defined in
|
Note that the conditional execution applies only to the operations defined in
|
||||||
`fn1` and `fn2`. Consider the following simple program:
|
`true_fn` and `false_fn`. Consider the following simple program:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
z = tf.multiply(a, b)
|
z = tf.multiply(a, b)
|
||||||
@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
Although this behavior is consistent with the dataflow model of TensorFlow,
|
Although this behavior is consistent with the dataflow model of TensorFlow,
|
||||||
it has occasionally surprised some users who expected a lazier semantics.
|
it has occasionally surprised some users who expected a lazier semantics.
|
||||||
|
|
||||||
|
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
|
||||||
|
call to `cond`, and not at all during `Session.run()`). `cond`
|
||||||
|
stitches together the graph fragments created during the `true_fn` and
|
||||||
|
`false_fn` calls with some additional graph nodes to ensure that the right
|
||||||
|
branch gets executed depending on the value of `pred`.
|
||||||
|
|
||||||
`tf.cond` supports nested structures as implemented in
|
`tf.cond` supports nested structures as implemented in
|
||||||
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
|
`tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
|
||||||
(possibly nested) value structure of lists, tuples, and/or named tuples.
|
same (possibly nested) value structure of lists, tuples, and/or named tuples.
|
||||||
Singleton lists and tuples form the only exceptions to this: when returned by
|
Singleton lists and tuples form the only exceptions to this: when returned by
|
||||||
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
|
`true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
|
||||||
behavior is disabled by passing `strict=True`.
|
This behavior is disabled by passing `strict=True`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pred: A scalar determining whether to return the result of `fn1` or `fn2`.
|
pred: A scalar determining whether to return the result of `true_fn` or
|
||||||
fn1: The callable to be performed if pred is true.
|
`false_fn`.
|
||||||
fn2: The callable to be performed if pred is false.
|
true_fn: The callable to be performed if pred is true.
|
||||||
|
false_fn: The callable to be performed if pred is false.
|
||||||
strict: A boolean that enables/disables 'strict' mode; see above.
|
strict: A boolean that enables/disables 'strict' mode; see above.
|
||||||
name: Optional name prefix for the returned tensors.
|
name: Optional name prefix for the returned tensors.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensors returned by the call to either `fn1` or `fn2`. If the callables
|
Tensors returned by the call to either `true_fn` or `false_fn`. If the
|
||||||
return a singleton list, the element is extracted from the list.
|
callables return a singleton list, the element is extracted from the list.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: if `fn1` or `fn2` is not callable.
|
TypeError: if `true_fn` or `false_fn` is not callable.
|
||||||
ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
|
ValueError: if `true_fn` and `false_fn` do not return the same number of
|
||||||
return tensors of different types.
|
tensors, or return tensors of different types.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
```
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, "cond", [pred]) as name:
|
# We needed to make true_fn/false_fn keyword arguments for
|
||||||
if not callable(fn1):
|
# backwards-compatibility. This check exists so that we can convert back to
|
||||||
raise TypeError("fn1 must be callable.")
|
# having them be positional arguments.
|
||||||
if not callable(fn2):
|
# TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
|
||||||
raise TypeError("fn2 must be callable.")
|
# `fn1` and `fn2` are deleted.
|
||||||
|
if fn1 is not None:
|
||||||
|
if true_fn is not None:
|
||||||
|
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
|
||||||
|
true_fn = fn1
|
||||||
|
elif true_fn is None:
|
||||||
|
raise TypeError("cond(): true_fn argument required")
|
||||||
|
if fn2 is not None:
|
||||||
|
if false_fn is not None:
|
||||||
|
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
|
||||||
|
false_fn = fn2
|
||||||
|
elif false_fn is None:
|
||||||
|
raise TypeError("cond(): false_fn argument required")
|
||||||
|
|
||||||
|
if not callable(true_fn):
|
||||||
|
raise TypeError("true_fn must be callable.")
|
||||||
|
if not callable(false_fn):
|
||||||
|
raise TypeError("false_fn must be callable.")
|
||||||
|
|
||||||
|
with ops.name_scope(name, "cond", [pred]) as name:
|
||||||
# Add the Switch to the graph.
|
# Add the Switch to the graph.
|
||||||
if isinstance(pred, bool):
|
if isinstance(pred, bool):
|
||||||
raise TypeError("pred must not be a Python bool")
|
raise TypeError("pred must not be a Python bool")
|
||||||
@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
# Build the graph for the true branch in a new context.
|
# Build the graph for the true branch in a new context.
|
||||||
context_t = CondContext(pred, pivot_1, branch=1)
|
context_t = CondContext(pred, pivot_1, branch=1)
|
||||||
context_t.Enter()
|
context_t.Enter()
|
||||||
orig_res_t, res_t = context_t.BuildCondBranch(fn1)
|
orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
|
||||||
if orig_res_t is None:
|
if orig_res_t is None:
|
||||||
raise ValueError("fn1 must have a return value.")
|
raise ValueError("true_fn must have a return value.")
|
||||||
context_t.ExitResult(res_t)
|
context_t.ExitResult(res_t)
|
||||||
context_t.Exit()
|
context_t.Exit()
|
||||||
|
|
||||||
# Build the graph for the false branch in a new context.
|
# Build the graph for the false branch in a new context.
|
||||||
context_f = CondContext(pred, pivot_2, branch=0)
|
context_f = CondContext(pred, pivot_2, branch=0)
|
||||||
context_f.Enter()
|
context_f.Enter()
|
||||||
orig_res_f, res_f = context_f.BuildCondBranch(fn2)
|
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
|
||||||
if orig_res_f is None:
|
if orig_res_f is None:
|
||||||
raise ValueError("fn2 must have a return value.")
|
raise ValueError("false_fn must have a return value.")
|
||||||
context_f.ExitResult(res_f)
|
context_f.ExitResult(res_f)
|
||||||
context_f.Exit()
|
context_f.Exit()
|
||||||
|
|
||||||
@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
nest.assert_same_structure(orig_res_t, orig_res_f)
|
nest.assert_same_structure(orig_res_t, orig_res_f)
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Incompatible return types of fn1 and fn2: {}".format(e))
|
"Incompatible return types of true_fn and false_fn: {}".format(e))
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Incompatible return values of fn1 and fn2: {}".format(e))
|
"Incompatible return values of true_fn and false_fn: {}".format(e))
|
||||||
|
|
||||||
# Add the final merge to the graph.
|
# Add the final merge to the graph.
|
||||||
if not res_t:
|
if not res_t:
|
||||||
raise ValueError("fn1 and fn2 must return at least one result.")
|
raise ValueError("true_fn and false_fn must return at least one result.")
|
||||||
|
|
||||||
res_t_flat = nest.flatten(res_t)
|
res_t_flat = nest.flatten(res_t)
|
||||||
res_f_flat = nest.flatten(res_f)
|
res_f_flat = nest.flatten(res_f)
|
||||||
@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
val_x = x if isinstance(x, ops.Tensor) else x.values
|
val_x = x if isinstance(x, ops.Tensor) else x.values
|
||||||
val_y = y if isinstance(y, ops.Tensor) else y.values
|
val_y = y if isinstance(y, ops.Tensor) else y.values
|
||||||
if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
|
if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
|
||||||
raise ValueError("Outputs of fn1 and fn2 must have the same type: "
|
raise ValueError(
|
||||||
"%s, %s" % (val_x.dtype.name, val_y.dtype.name))
|
"Outputs of true_fn and false_fn must have the same type: %s, %s" %
|
||||||
|
(val_x.dtype.name, val_y.dtype.name))
|
||||||
|
|
||||||
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
|
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
|
||||||
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
|
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
|
||||||
@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
|
|||||||
if not strict:
|
if not strict:
|
||||||
merges = _UnpackIfSingleton(merges)
|
merges = _UnpackIfSingleton(merges)
|
||||||
return merges
|
return merges
|
||||||
|
# pylint: enable=g-doc-args
|
||||||
|
|
||||||
|
|
||||||
def _resource_safe_shape(t):
|
def _resource_safe_shape(t):
|
||||||
@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
|
|||||||
`cond` and `body`. `cond` and `body` both take as many arguments as there are
|
`cond` and `body`. `cond` and `body` both take as many arguments as there are
|
||||||
`loop_vars`.
|
`loop_vars`.
|
||||||
|
|
||||||
While `cond` evaluates to true, `body` is executed.
|
|
||||||
|
|
||||||
In addition to regular Tensors or IndexedSlices, the body may accept and
|
In addition to regular Tensors or IndexedSlices, the body may accept and
|
||||||
return TensorArray objects. The flows of the TensorArray objects will
|
return TensorArray objects. The flows of the TensorArray objects will
|
||||||
be appropriately forwarded between loops and during gradient calculations.
|
be appropriately forwarded between loops and during gradient calculations.
|
||||||
|
|
||||||
|
Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
|
||||||
|
call to `while_loop`, and not at all during `Session.run()`). `while_loop`
|
||||||
|
stitches together the graph fragments created during the `cond` and `body`
|
||||||
|
calls with some additional graph nodes to make something the repeats
|
||||||
|
`body` until `cond` returns false.
|
||||||
|
|
||||||
For correctness, `tf.while_loop()` strictly enforces shape invariants for
|
For correctness, `tf.while_loop()` strictly enforces shape invariants for
|
||||||
the loop variables. A shape invariant is a (possibly partial) shape that
|
the loop variables. A shape invariant is a (possibly partial) shape that
|
||||||
is unchanged across the iterations of the loop. An error will be raised
|
is unchanged across the iterations of the loop. An error will be raised
|
||||||
@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
|
|||||||
operation returns the tensors generated by `default`.
|
operation returns the tensors generated by `default`.
|
||||||
|
|
||||||
`tf.case` supports nested structures as implemented in
|
`tf.case` supports nested structures as implemented in
|
||||||
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
|
`tensorflow.python.util.nest`. All of the callables must return the same
|
||||||
(possibly nested) value structure of lists, tuples, and/or named tuples.
|
(possibly nested) value structure of lists, tuples, and/or named tuples.
|
||||||
Singleton lists and tuples form the only exceptions to this: when returned by
|
Singleton lists and tuples form the only exceptions to this: when returned by
|
||||||
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
|
a callable, they are implicitly unpacked to single values. This
|
||||||
behavior is disabled by passing `strict=True`.
|
behavior is disabled by passing `strict=True`.
|
||||||
|
|
||||||
Example 1:
|
Example 1:
|
||||||
@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
|
|||||||
|
|
||||||
Expressions:
|
Expressions:
|
||||||
```
|
```
|
||||||
x = tf.constant(0)
|
|
||||||
y = tf.constant(1)
|
|
||||||
z = tf.constant(2)
|
|
||||||
def f1(): return tf.constant(17)
|
def f1(): return tf.constant(17)
|
||||||
def f2(): return tf.constant(23)
|
def f2(): return tf.constant(23)
|
||||||
def f3(): return tf.constant(-1)
|
def f3(): return tf.constant(-1)
|
||||||
|
@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
|
|||||||
self.assertEquals(grad_x_false.eval(), 0.)
|
self.assertEquals(grad_x_false.eval(), 0.)
|
||||||
|
|
||||||
|
|
||||||
|
class CondTest(TensorFlowTestCase):
|
||||||
|
|
||||||
|
def testCondTrue(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(2)
|
||||||
|
y = constant_op.constant(5)
|
||||||
|
z = control_flow_ops.cond(
|
||||||
|
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
|
||||||
|
lambda: math_ops.add(y, 23))
|
||||||
|
self.assertEquals(z.eval(), 34)
|
||||||
|
|
||||||
|
def testCondFalse(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(2)
|
||||||
|
y = constant_op.constant(1)
|
||||||
|
z = control_flow_ops.cond(
|
||||||
|
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
|
||||||
|
lambda: math_ops.add(y, 23))
|
||||||
|
self.assertEquals(z.eval(), 24)
|
||||||
|
|
||||||
|
def testCondTrueLegacy(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(2)
|
||||||
|
y = constant_op.constant(5)
|
||||||
|
z = control_flow_ops.cond(
|
||||||
|
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
|
||||||
|
fn2=lambda: math_ops.add(y, 23))
|
||||||
|
self.assertEquals(z.eval(), 34)
|
||||||
|
|
||||||
|
def testCondFalseLegacy(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(2)
|
||||||
|
y = constant_op.constant(1)
|
||||||
|
z = control_flow_ops.cond(
|
||||||
|
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
|
||||||
|
fn2=lambda: math_ops.add(y, 23))
|
||||||
|
self.assertEquals(z.eval(), 24)
|
||||||
|
|
||||||
|
def testCondMissingArg1(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
control_flow_ops.cond(True, false_fn=lambda: x)
|
||||||
|
|
||||||
|
def testCondMissingArg2(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
control_flow_ops.cond(True, lambda: x)
|
||||||
|
|
||||||
|
def testCondDuplicateArg1(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
|
||||||
|
|
||||||
|
def testCondDuplicateArg2(self):
|
||||||
|
with self.test_session():
|
||||||
|
x = constant_op.constant(1)
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
|
||||||
|
|
||||||
|
|
||||||
class ContextTest(TensorFlowTestCase):
|
class ContextTest(TensorFlowTestCase):
|
||||||
|
|
||||||
def testCondContext(self):
|
def testCondContext(self):
|
||||||
|
@ -718,7 +718,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "cond"
|
name: "cond"
|
||||||
argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
|
argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "confusion_matrix"
|
name: "confusion_matrix"
|
||||||
|
Loading…
Reference in New Issue
Block a user