diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index a22e4819d51..91be9ddbd78 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops.gen_control_flow_ops import * # pylint: enable=wildcard-import from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import deprecation from tensorflow.python.util import nest from tensorflow.python.util import tf_should_use @@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res): return res -def cond(pred, fn1, fn2, strict=False, name=None): - """Return `fn1()` if the boolean predicate `pred` is true else `fn2()`. +# pylint: disable=g-doc-args +@deprecation.deprecated_args( + None, + "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", + "fn1", "fn2") +def cond(pred, true_fn=None, false_fn=None, strict=False, name=None, + fn1=None, fn2=None): + """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. - `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have - the same non-zero number and type of outputs. + `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and + `false_fn` must have the same non-zero number and type of outputs. Note that the conditional execution applies only to the operations defined in - `fn1` and `fn2`. Consider the following simple program: + `true_fn` and `false_fn`. Consider the following simple program: ```python z = tf.multiply(a, b) @@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None): Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics. + Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the + call to `cond`, and not at all during `Session.run()`). `cond` + stitches together the graph fragments created during the `true_fn` and + `false_fn` calls with some additional graph nodes to ensure that the right + branch gets executed depending on the value of `pred`. + `tf.cond` supports nested structures as implemented in - `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same - (possibly nested) value structure of lists, tuples, and/or named tuples. + `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the + same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by - `fn1` and/or `fn2`, they are implicitly unpacked to single values. This - behavior is disabled by passing `strict=True`. + `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. + This behavior is disabled by passing `strict=True`. Args: - pred: A scalar determining whether to return the result of `fn1` or `fn2`. - fn1: The callable to be performed if pred is true. - fn2: The callable to be performed if pred is false. + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. strict: A boolean that enables/disables 'strict' mode; see above. name: Optional name prefix for the returned tensors. Returns: - Tensors returned by the call to either `fn1` or `fn2`. If the callables - return a singleton list, the element is extracted from the list. 
+ Tensors returned by the call to either `true_fn` or `false_fn`. If the + callables return a singleton list, the element is extracted from the list. Raises: - TypeError: if `fn1` or `fn2` is not callable. - ValueError: if `fn1` and `fn2` do not return the same number of tensors, or - return tensors of different types. + TypeError: if `true_fn` or `false_fn` is not callable. + ValueError: if `true_fn` and `false_fn` do not return the same number of + tensors, or return tensors of different types. Example: @@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None): ``` """ - with ops.name_scope(name, "cond", [pred]) as name: - if not callable(fn1): - raise TypeError("fn1 must be callable.") - if not callable(fn2): - raise TypeError("fn2 must be callable.") + # We needed to make true_fn/false_fn keyword arguments for + # backwards-compatibility. This check exists so that we can convert back to + # having them be positional arguments. + # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after + # `fn1` and `fn2` are deleted. + if fn1 is not None: + if true_fn is not None: + raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") + true_fn = fn1 + elif true_fn is None: + raise TypeError("cond(): true_fn argument required") + if fn2 is not None: + if false_fn is not None: + raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") + false_fn = fn2 + elif false_fn is None: + raise TypeError("cond(): false_fn argument required") + if not callable(true_fn): + raise TypeError("true_fn must be callable.") + if not callable(false_fn): + raise TypeError("false_fn must be callable.") + + with ops.name_scope(name, "cond", [pred]) as name: # Add the Switch to the graph. if isinstance(pred, bool): raise TypeError("pred must not be a Python bool") @@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None): # Build the graph for the true branch in a new context. context_t = CondContext(pred, pivot_1, branch=1) context_t.Enter() - orig_res_t, res_t = context_t.BuildCondBranch(fn1) + orig_res_t, res_t = context_t.BuildCondBranch(true_fn) if orig_res_t is None: - raise ValueError("fn1 must have a return value.") + raise ValueError("true_fn must have a return value.") context_t.ExitResult(res_t) context_t.Exit() # Build the graph for the false branch in a new context. context_f = CondContext(pred, pivot_2, branch=0) context_f.Enter() - orig_res_f, res_f = context_f.BuildCondBranch(fn2) + orig_res_f, res_f = context_f.BuildCondBranch(false_fn) if orig_res_f is None: - raise ValueError("fn2 must have a return value.") + raise ValueError("false_fn must have a return value.") context_f.ExitResult(res_f) context_f.Exit() @@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None): nest.assert_same_structure(orig_res_t, orig_res_f) except TypeError as e: raise TypeError( - "Incompatible return types of fn1 and fn2: {}".format(e)) + "Incompatible return types of true_fn and false_fn: {}".format(e)) except ValueError as e: raise ValueError( - "Incompatible return values of fn1 and fn2: {}".format(e)) + "Incompatible return values of true_fn and false_fn: {}".format(e)) # Add the final merge to the graph. 
if not res_t: - raise ValueError("fn1 and fn2 must return at least one result.") + raise ValueError("true_fn and false_fn must return at least one result.") res_t_flat = nest.flatten(res_t) res_f_flat = nest.flatten(res_f) @@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None): val_x = x if isinstance(x, ops.Tensor) else x.values val_y = y if isinstance(y, ops.Tensor) else y.values if val_x.dtype.base_dtype != val_y.dtype.base_dtype: - raise ValueError("Outputs of fn1 and fn2 must have the same type: " - "%s, %s" % (val_x.dtype.name, val_y.dtype.name)) + raise ValueError( + "Outputs of true_fn and false_fn must have the same type: %s, %s" % + (val_x.dtype.name, val_y.dtype.name)) merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges) @@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None): if not strict: merges = _UnpackIfSingleton(merges) return merges +# pylint: enable=g-doc-args def _resource_safe_shape(t): @@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, `cond` and `body`. `cond` and `body` both take as many arguments as there are `loop_vars`. - While `cond` evaluates to true, `body` is executed. - In addition to regular Tensors or IndexedSlices, the body may accept and return TensorArray objects. The flows of the TensorArray objects will be appropriately forwarded between loops and during gradient calculations. + Note that `while_loop` calls `cond` and `body` *exactly once* (inside the + call to `while_loop`, and not at all during `Session.run()`). `while_loop` + stitches together the graph fragments created during the `cond` and `body` + calls with some additional graph nodes to make something that repeats + `body` until `cond` returns false. + For correctness, `tf.while_loop()` strictly enforces shape invariants for the loop variables. A shape invariant is a (possibly partial) shape that is unchanged across the iterations of the loop. An error will be raised @@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"): operation returns the tensors generated by `default`. `tf.case` supports nested structures as implemented in - `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same + `tensorflow.python.util.nest`. All of the callables must return the same (possibly nested) value structure of lists, tuples, and/or named tuples. Singleton lists and tuples form the only exceptions to this: when returned by - `fn1` and/or `fn2`, they are implicitly unpacked to single values. This + a callable, they are implicitly unpacked to single values. This behavior is disabled by passing `strict=True`. Example 1: @@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"): Expressions: ``` - x = tf.constant(0) - y = tf.constant(1) - z = tf.constant(2) def f1(): return tf.constant(17) def f2(): return tf.constant(23) def f3(): return tf.constant(-1) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index 7704254b013..4e95783e5a8 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ b/tensorflow/python/ops/control_flow_ops_test.py @@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase): self.assertEquals(grad_x_false.eval(), 0.)
+class CondTest(TensorFlowTestCase): + + def testCondTrue(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.cond( + math_ops.less(x, y), lambda: math_ops.multiply(x, 17), + lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 34) + + def testCondFalse(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(1) + z = control_flow_ops.cond( + math_ops.less(x, y), lambda: math_ops.multiply(x, 17), + lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 24) + + def testCondTrueLegacy(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = control_flow_ops.cond( + math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), + fn2=lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 34) + + def testCondFalseLegacy(self): + with self.test_session(): + x = constant_op.constant(2) + y = constant_op.constant(1) + z = control_flow_ops.cond( + math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17), + fn2=lambda: math_ops.add(y, 23)) + self.assertEquals(z.eval(), 24) + + def testCondMissingArg1(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, false_fn=lambda: x) + + def testCondMissingArg2(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x) + + def testCondDuplicateArg1(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) + + def testCondDuplicateArg2(self): + with self.test_session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) + + class ContextTest(TensorFlowTestCase): def testCondContext(self): diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index c3c4145763b..fb636d9525f 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -718,7 +718,7 @@ tf_module { } member_method { name: "cond" - argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " + argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], " } member_method { name: "confusion_matrix"
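A minimal usage sketch of the renamed signature, assuming TensorFlow 1.x graph mode (`tf.Session`); the constants mirror the new `CondTest` cases added in this change.

```python
import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

# New-style call: the branch callables are named true_fn/false_fn and may be
# passed positionally or by keyword.
z_new = tf.cond(tf.less(x, y),
                true_fn=lambda: tf.multiply(x, 17),
                false_fn=lambda: tf.add(y, 23))

# Legacy call: fn1/fn2 still work, but the shim added in this change routes
# them to true_fn/false_fn and emits a deprecation warning.
z_old = tf.cond(tf.less(x, y),
                fn1=lambda: tf.multiply(x, 17),
                fn2=lambda: tf.add(y, 23))

with tf.Session() as sess:
  print(sess.run([z_new, z_old]))  # [34, 34]: 2 < 5 selects the true branch.
```

Passing both spellings for the same branch (for example `true_fn=...` together with `fn1=...`) raises a `TypeError`, which is what the `testCondDuplicateArg*` cases exercise.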
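The docstring additions for both `cond` and `while_loop` stress that the branch callables run exactly once, at graph-construction time, and that only the resulting graph nodes are conditionally executed. A small sketch of what that means in practice, again assuming TF 1.x graph mode; the `print` calls are only there to show when the Python functions run.

```python
import tensorflow as tf

def true_fn():
  print("true_fn traced")   # runs once, while tf.cond builds the graph
  return tf.constant(1)

def false_fn():
  print("false_fn traced")  # also runs once, at graph-construction time
  return tf.constant(0)

pred = tf.placeholder(tf.bool, shape=[])
result = tf.cond(pred, true_fn, false_fn)
# Both "traced" messages have already been printed at this point,
# regardless of what pred will be fed later.

with tf.Session() as sess:
  # At run time only the ops from the selected branch execute.
  print(sess.run(result, feed_dict={pred: True}))   # 1
  print(sess.run(result, feed_dict={pred: False}))  # 0
```

The same reasoning applies to `while_loop`: `cond` and `body` are traced once, and the additional control-flow nodes make the traced `body` graph repeat until the traced `cond` graph yields false.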