From 7c561e09c05100fe68f00d66a2d27d1b490ee74e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 1 May 2017 15:11:37 -0800
Subject: [PATCH] Explain when callables passed to tf.cond & tf.while_loop are
 run. Rename the parameters to tf.cond. Change: 154774725

---
 tensorflow/python/ops/control_flow_ops.py     | 111 ++++++++++++------
 .../python/ops/control_flow_ops_test.py       |  63 ++++++++++
 tensorflow/tools/api/golden/tensorflow.pbtxt  |   2 +-
 3 files changed, 137 insertions(+), 39 deletions(-)

diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index a22e4819d51..91be9ddbd78 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -71,6 +71,7 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops.gen_control_flow_ops import *
 # pylint: enable=wildcard-import
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_should_use
 
@@ -1679,14 +1680,20 @@ def _UnpackIfSingleton(res):
     return res
 
 
-def cond(pred, fn1, fn2, strict=False, name=None):
-  """Return `fn1()` if the boolean predicate `pred` is true else `fn2()`.
+# pylint: disable=g-doc-args
+@deprecation.deprecated_args(
+    None,
+    "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
+    "fn1", "fn2")
+def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
+         fn1=None, fn2=None):
+  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
 
-  `fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
-  the same non-zero number and type of outputs.
+  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
+  `false_fn` must have the same non-zero number and type of outputs.
 
   Note that the conditional execution applies only to the operations defined in
-  `fn1` and `fn2`. Consider the following simple program:
+  `true_fn` and `false_fn`. Consider the following simple program:
 
   ```python
   z = tf.multiply(a, b)
@@ -1700,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
   Although this behavior is consistent with the dataflow model of TensorFlow,
   it has occasionally surprised some users who expected a lazier semantics.
 
+  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
+  call to `cond`, and not at all during `Session.run()`). `cond`
+  stitches together the graph fragments created during the `true_fn` and
+  `false_fn` calls with some additional graph nodes to ensure that the right
+  branch gets executed depending on the value of `pred`.
+
   `tf.cond` supports nested structures as implemented in
-  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
-  (possibly nested) value structure of lists, tuples, and/or named tuples.
+  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
+  same (possibly nested) value structure of lists, tuples, and/or named tuples.
   Singleton lists and tuples form the only exceptions to this: when returned by
-  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
-  behavior is disabled by passing `strict=True`.
+  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
+  This behavior is disabled by passing `strict=True`.
 
   Args:
-    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
-    fn1: The callable to be performed if pred is true.
-    fn2: The callable to be performed if pred is false.
+    pred: A scalar determining whether to return the result of `true_fn` or
+      `false_fn`.
+    true_fn: The callable to be performed if pred is true.
+    false_fn: The callable to be performed if pred is false.
     strict: A boolean that enables/disables 'strict' mode; see above.
     name: Optional name prefix for the returned tensors.
 
   Returns:
-    Tensors returned by the call to either `fn1` or `fn2`. If the callables
-    return a singleton list, the element is extracted from the list.
+    Tensors returned by the call to either `true_fn` or `false_fn`. If the
+    callables return a singleton list, the element is extracted from the list.
 
   Raises:
-    TypeError: if `fn1` or `fn2` is not callable.
-    ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
-                return tensors of different types.
+    TypeError: if `true_fn` or `false_fn` is not callable.
+    ValueError: if `true_fn` and `false_fn` do not return the same number of
+      tensors, or return tensors of different types.
 
   Example:
 
@@ -1736,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
   ```
 
   """
-  with ops.name_scope(name, "cond", [pred]) as name:
-    if not callable(fn1):
-      raise TypeError("fn1 must be callable.")
-    if not callable(fn2):
-      raise TypeError("fn2 must be callable.")
+  # We needed to make true_fn/false_fn keyword arguments for
+  # backwards-compatibility. This check exists so that we can convert back to
+  # having them be positional arguments.
+  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
+  # `fn1` and `fn2` are deleted.
+  if fn1 is not None:
+    if true_fn is not None:
+      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
+    true_fn = fn1
+  elif true_fn is None:
+    raise TypeError("cond(): true_fn argument required")
+  if fn2 is not None:
+    if false_fn is not None:
+      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
+    false_fn = fn2
+  elif false_fn is None:
+    raise TypeError("cond(): false_fn argument required")
 
+  if not callable(true_fn):
+    raise TypeError("true_fn must be callable.")
+  if not callable(false_fn):
+    raise TypeError("false_fn must be callable.")
+
+  with ops.name_scope(name, "cond", [pred]) as name:
     # Add the Switch to the graph.
     if isinstance(pred, bool):
       raise TypeError("pred must not be a Python bool")
@@ -1756,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
     # Build the graph for the true branch in a new context.
     context_t = CondContext(pred, pivot_1, branch=1)
     context_t.Enter()
-    orig_res_t, res_t = context_t.BuildCondBranch(fn1)
+    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
     if orig_res_t is None:
-      raise ValueError("fn1 must have a return value.")
+      raise ValueError("true_fn must have a return value.")
     context_t.ExitResult(res_t)
     context_t.Exit()
 
     # Build the graph for the false branch in a new context.
     context_f = CondContext(pred, pivot_2, branch=0)
     context_f.Enter()
-    orig_res_f, res_f = context_f.BuildCondBranch(fn2)
+    orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
     if orig_res_f is None:
-      raise ValueError("fn2 must have a return value.")
+      raise ValueError("false_fn must have a return value.")
     context_f.ExitResult(res_f)
     context_f.Exit()
 
@@ -1780,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
       nest.assert_same_structure(orig_res_t, orig_res_f)
     except TypeError as e:
       raise TypeError(
-          "Incompatible return types of fn1 and fn2: {}".format(e))
+          "Incompatible return types of true_fn and false_fn: {}".format(e))
     except ValueError as e:
       raise ValueError(
-          "Incompatible return values of fn1 and fn2: {}".format(e))
+          "Incompatible return values of true_fn and false_fn: {}".format(e))
 
     # Add the final merge to the graph.
     if not res_t:
-      raise ValueError("fn1 and fn2 must return at least one result.")
+      raise ValueError("true_fn and false_fn must return at least one result.")
 
     res_t_flat = nest.flatten(res_t)
     res_f_flat = nest.flatten(res_f)
@@ -1801,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
       val_x = x if isinstance(x, ops.Tensor) else x.values
       val_y = y if isinstance(y, ops.Tensor) else y.values
       if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
-        raise ValueError("Outputs of fn1 and fn2 must have the same type: "
-                         "%s, %s" % (val_x.dtype.name, val_y.dtype.name))
+        raise ValueError(
+            "Outputs of true_fn and false_fn must have the same type: %s, %s" %
+            (val_x.dtype.name, val_y.dtype.name))
 
     merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
     merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
@@ -1817,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
     if not strict:
       merges = _UnpackIfSingleton(merges)
     return merges
+# pylint: enable=g-doc-args
 
 
 def _resource_safe_shape(t):
@@ -2548,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
   `cond` and `body`. `cond` and `body` both take as many arguments as there are
   `loop_vars`.
 
-  While `cond` evaluates to true, `body` is executed.
-
   In addition to regular Tensors or IndexedSlices, the body may accept and
   return TensorArray objects.  The flows of the TensorArray objects will
   be appropriately forwarded between loops and during gradient calculations.
 
+  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
+  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
+  stitches together the graph fragments created during the `cond` and `body`
+  calls with some additional graph nodes to make something the repeats
+  `body` until `cond` returns false.
+
   For correctness, `tf.while_loop()` strictly enforces shape invariants for
   the loop variables. A shape invariant is a (possibly partial) shape that
   is unchanged across the iterations of the loop. An error will be raised
@@ -2882,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
   operation returns the tensors generated by `default`.
 
   `tf.case` supports nested structures as implemented in
-  `tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
+  `tensorflow.python.util.nest`. All of the callables must return the same
   (possibly nested) value structure of lists, tuples, and/or named tuples.
   Singleton lists and tuples form the only exceptions to this: when returned by
-  `fn1` and/or `fn2`, they are implicitly unpacked to single values. This
+  a callable, they are implicitly unpacked to single values. This
   behavior is disabled by passing `strict=True`.
 
   Example 1:
@@ -2913,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
 
     Expressions:
     ```
-      x = tf.constant(0)
-      y = tf.constant(1)
-      z = tf.constant(2)
       def f1(): return tf.constant(17)
       def f2(): return tf.constant(23)
       def f3(): return tf.constant(-1)
diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py
index 7704254b013..4e95783e5a8 100644
--- a/tensorflow/python/ops/control_flow_ops_test.py
+++ b/tensorflow/python/ops/control_flow_ops_test.py
@@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
       self.assertEquals(grad_x_false.eval(), 0.)
 
 
+class CondTest(TensorFlowTestCase):
+
+  def testCondTrue(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(5)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+          lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 34)
+
+  def testCondFalse(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(1)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
+          lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 24)
+
+  def testCondTrueLegacy(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(5)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
+          fn2=lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 34)
+
+  def testCondFalseLegacy(self):
+    with self.test_session():
+      x = constant_op.constant(2)
+      y = constant_op.constant(1)
+      z = control_flow_ops.cond(
+          math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
+          fn2=lambda: math_ops.add(y, 23))
+      self.assertEquals(z.eval(), 24)
+
+  def testCondMissingArg1(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, false_fn=lambda: x)
+
+  def testCondMissingArg2(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x)
+
+  def testCondDuplicateArg1(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
+
+  def testCondDuplicateArg2(self):
+    with self.test_session():
+      x = constant_op.constant(1)
+      with self.assertRaises(TypeError):
+        control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
+
+
 class ContextTest(TensorFlowTestCase):
 
   def testCondContext(self):
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index c3c4145763b..fb636d9525f 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -718,7 +718,7 @@ tf_module {
   }
   member_method {
     name: "cond"
-    argspec: "args=[\'pred\', \'fn1\', \'fn2\', \'strict\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+    argspec: "args=[\'pred\', \'true_fn\', \'false_fn\', \'strict\', \'name\', \'fn1\', \'fn2\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "confusion_matrix"