From a74f9c3c612586ba4581bd6324a7c1ced69ec5a3 Mon Sep 17 00:00:00 2001
From: Edward Loper <edloper@google.com>
Date: Thu, 25 Apr 2019 12:30:56 -0700
Subject: [PATCH] Add support to for composite tensors, such as SparseTensor
 and RaggedTensor, to while_v2

PiperOrigin-RevId: 245285953
---
 tensorflow/python/framework/ops.py            |   4 +-
 tensorflow/python/framework/sparse_tensor.py  |   4 +-
 .../kernel_tests/control_flow_ops_py_test.py  | 140 +++++++++++-------
 tensorflow/python/ops/control_flow_ops.py     |   2 +-
 tensorflow/python/ops/while_v2.py             |  43 ++++--
 5 files changed, 118 insertions(+), 75 deletions(-)

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index d23cfc77c94..f806a15a94a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1761,10 +1761,10 @@ class IndexedSlices(_TensorLike, composite_tensor.CompositeTensor):
     if shape is None:
       shape = self._values.shape
     if self._dense_shape is None:
-      return [shape, shape[:1]]  # values, indices
+      return (shape, shape[:1])  # values, indices
     else:
       # values, indices, dense_shape
-      return [shape, shape[:1], tensor_shape.TensorShape([shape.ndims])]
+      return (shape, shape[:1], tensor_shape.TensorShape([shape.ndims]))
 
   @property
   def _is_graph_tensor(self):
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index e9199b1e661..cb427b4a4d2 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -250,11 +250,11 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
       raise ValueError("Shape invariant for SparseTensor must have the form "
                        "TensorShape([r]), got %r" % shape)
     rank = tensor_shape.dimension_value(shape[0])
-    return [
+    return (
         tensor_shape.TensorShape([None, rank]),  # indices
         tensor_shape.TensorShape([None]),  # values
         tensor_shape.TensorShape([rank])  # dense_shape
-    ]
+        )
 
   @property
   def _is_graph_tensor(self):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 2f89ee4746c..43f57112bb0 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1790,6 +1790,18 @@ class ControlFlowTest(test.TestCase):
       r = r[1] * array_ops.ones([8, 8])
       self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
 
+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShape(self):
+    x = constant_op.constant([2.0, 4.0], name="values")
+    i = constant_op.constant(0)
+    c = lambda i, _: math_ops.less(i, 10)
+    b = lambda i, x: [i + 1, x + 1]
+    with self.assertRaisesRegexp(ValueError, "is not compatible with"):
+      # Shape of x is [2], but we specify a shape of [5].
+      control_flow_ops.while_loop(
+          c, b, [i, x], [i.shape, tensor_shape.TensorShape([5])])
+
   @test_util.run_deprecated_v1
   def testWhileWithNonTensorInput_Scalar(self):
     with self.cached_session():
@@ -1807,7 +1819,6 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
       self.assertEqual([10000], self.evaluate(r))
 
-  @test_util.run_v1_only("b/120545219")
   def testWhileShapeInference(self):
     with self.cached_session():
       i = constant_op.constant(0)
@@ -1822,19 +1833,23 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.while_loop(
           c, b, [i, m],
           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
-      self.assertIsNone(r[1].shape.dims[0].value)
-      self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2))
+      self.assertTrue(r[1].shape.is_compatible_with([8, 2]))
 
+  @test_util.run_v1_only("b/120545219")
+  def testWhileShapeInferenceBadShape(self):
+    with self.cached_session():
+      i = constant_op.constant(0)
+      m = array_ops.ones([2, 2])
+      c = lambda i, j: math_ops.less(i, 2)
+      b = lambda i, j: [i + 1, array_ops.concat([j, j], 0)]
       with self.assertRaisesRegexp(
           ValueError,
           r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
           r"shape \(4, 2\) after one iteration. To allow the shape to vary "
           r"across iterations, use the `shape_invariants` argument of "
           r"tf.while_loop to specify a less-specific shape."):
-        r = control_flow_ops.while_loop(c, b, [i, m])
+        control_flow_ops.while_loop(c, b, [i, m])
 
-  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
-  @test_util.run_v1_only("b/120545219")
   def testWhileShapeInferenceSparseTensor(self):
     values = constant_op.constant([2.0, 4.0], name="values")
     indices = constant_op.constant([[0], [3]],
@@ -1873,61 +1888,72 @@ class ControlFlowTest(test.TestCase):
               array_ops.concat([x.dense_shape, [10]], axis=0))
       ]
 
+    def check_shapes(r, indices, values, dense_shape):
+      self.assertTrue(r.indices.shape.is_compatible_with(indices))
+      self.assertTrue(r.values.shape.is_compatible_with(values))
+      self.assertTrue(r.dense_shape.shape.is_compatible_with(dense_shape))
+
     # Default shape invariant; b1 only modifies values.
     _, r = control_flow_ops.while_loop(c, b1, [i, x])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
+    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
 
     # Default shape invariant; b2 adds new values
     _, r = control_flow_ops.while_loop(c, b2, [i, x])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
-
-    # Default shape invariant; b3 modifies rank (which is not allowed).
-    with self.assertRaises(ValueError):
-      _, r = control_flow_ops.while_loop(c, b3, [i, x])
+    check_shapes(r, indices=[None, 1], values=[None], dense_shape=[1])
 
     # Explicit shape invariant, allowing any rank; b1 only modifies values.
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None])])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
 
     # Explicit shape invariant, allowing any rank; b3 modifies rank.
     _, r = control_flow_ops.while_loop(
         c, b3, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None])])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
 
     # Shape invariant with ndims=None.  Technically, this isn't supported
     # according to the docs, but we support it for backwards compatibility.
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape(None)])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
     _, r = control_flow_ops.while_loop(
         c, b3, [i, x],
         [i.get_shape(), tensor_shape.TensorShape(None)])
-    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
-    self.assertEqual(r.values.get_shape().as_list(), [None])
-    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
+    check_shapes(r, indices=[None, None], values=[None], dense_shape=[None])
+
+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShapeSparseTensor(self):
+    values = constant_op.constant([2.0, 4.0], name="values")
+    indices = constant_op.constant([[0], [3]],
+                                   dtype=dtypes.int64,
+                                   name="indices")
+    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
+    i = constant_op.constant(0)
+    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
+    c = lambda i, _: i < 10
+    b1 = lambda i, x: [i+1, x]
+    def b2(i, x):  # modifies rank.  (shape of all components is changed.)
+      return [
+          i + 1,
+          sparse_tensor.SparseTensor(
+              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
+              array_ops.concat([x.dense_shape, [10]], axis=0))
+      ]
 
     # Explicit shape invariant, with a specific (incompatible) rank.
     with self.assertRaisesRegexp(ValueError, "is not compatible with"):
-      _, r = control_flow_ops.while_loop(
+      control_flow_ops.while_loop(
           c, b1, [i, x],
           [i.get_shape(), tensor_shape.TensorShape([5])])
 
-  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
-  @test_util.run_v1_only("b/120545219")
+    # Default shape invariant, but b2 modifies rank (which is not allowed).
+    with self.assertRaises(ValueError):
+      control_flow_ops.while_loop(c, b2, [i, x])
+
   def testWhileShapeInferenceIndexedSlices(self):
     with self.cached_session():
       values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1953,17 +1979,28 @@ class ControlFlowTest(test.TestCase):
           c, b, [i, x],
           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
       self.assertEqual(r.dense_shape.get_shape()[0], 2)
-      self.assertEqual(r.values.get_shape().as_list(), [None, 2])
+      self.assertTrue(r.values.get_shape().is_compatible_with([None, 2]))
 
-      with self.assertRaisesRegexp(ValueError, "is not compatible with"):
-        _, r = control_flow_ops.while_loop(
-            c, b, [i, x],
-            [i.get_shape(), tensor_shape.TensorShape([None, 5])])
+  @test_util.disable_control_flow_v2("b/131265085")
+  @test_util.run_v1_only("b/131265085")
+  def testWhileBadShapeIndexedSlices(self):
+    values = constant_op.constant([2.0, 4.0], name="values")
+    indices = constant_op.constant([[0], [3]],
+                                   dtype=dtypes.int64,
+                                   name="indices")
+    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
+    i = constant_op.constant(0)
+    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
+    c = lambda i, _: 10
+    b = lambda i, x: [i+1, x]
+
+    # Explicit shape invariant, with a specific (incompatible) rank.
+    with self.assertRaisesRegexp(ValueError, "is not compatible with"):
+      control_flow_ops.while_loop(
+          c, b, [i, x],
+          [i.get_shape(), tensor_shape.TensorShape([5])])
 
-  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   def testWhileShapeInferenceRaggedTensor(self):
-    if context.executing_eagerly():
-      self.skipTest("b/116328420")
     i = constant_op.constant(0)
     x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
     c = lambda i, _: i < 10
@@ -1980,11 +2017,13 @@ class ControlFlowTest(test.TestCase):
           array_ops.concat([x, x], axis=0)
       ]
 
+    def check_shapes(r, values, splits):
+      self.assertTrue(r.values.shape.is_compatible_with(values))
+      self.assertTrue(r.row_splits.shape.is_compatible_with(splits))
+
     # Default shape invariant; b1 adds new values to rows.
     _, r = control_flow_ops.while_loop(c, b1, [i, x])
-    self.assertEqual(r.row_splits.shape.as_list(), [4])
-
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[4])
 
     # Default shape invariant; b2 adds new rows (not allowed).
     if not context.executing_eagerly():
@@ -1995,20 +2034,15 @@ class ControlFlowTest(test.TestCase):
     _, r = control_flow_ops.while_loop(
         c, b1, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None, None])])
-    self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[None])
 
     # Explicit shape invariant; b2 adds new rows.
     _, r = control_flow_ops.while_loop(
         c, b2, [i, x],
         [i.get_shape(), tensor_shape.TensorShape([None, None])])
-    self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
-    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
+    check_shapes(r, values=[None], splits=[None])
 
-  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
-    if context.executing_eagerly():
-      self.skipTest("b/116328420")
     i = constant_op.constant(0)
     x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
                                      [[], [8, 9, 10]]])
@@ -3473,8 +3507,7 @@ class ControlFlowTest(test.TestCase):
     self.assertEqual(0, value_x)
     self.assertEqual(73, value_x_grad)
 
-  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testWhileGrad_IndexedSlices(self):
     with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
@@ -3496,8 +3529,7 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r.values, values)[0]
       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
 
-  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testWhileGrad_SparseTensor(self):
     with self.cached_session():
       values = constant_op.constant([2.0, 4.0], name="values")
@@ -3520,7 +3552,7 @@ class ControlFlowTest(test.TestCase):
       r = gradients_impl.gradients(r.values, values)[0]
       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.deprecated_graph_mode_only
   def testCallGradInLoop(self):
     with self.cached_session() as sess:
       i0 = constant_op.constant(0)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index c1181a74dc6..daa99be060f 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -3466,7 +3466,7 @@ def while_loop(cond,
           return x
         return ops.convert_to_tensor(x)
 
-      loop_vars = nest.map_structure(convert, loop_vars)
+      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
       if maximum_iterations is not None:
         return loop_vars[1]
       else:
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index cbbc0de9370..50e7d244d5e 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -72,12 +72,18 @@ def while_loop(cond,
   # `wrapped_body` below.
   loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
   loop_vars = nest.map_structure(
-      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
+      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
+      expand_composites=True)
   if shape_invariants is not None:
-    nest.assert_same_structure(orig_loop_vars, shape_invariants)
+    nest.assert_same_structure(orig_loop_vars, shape_invariants,
+                               expand_composites=False)
+    shape_invariants = nest.map_structure(
+        control_flow_ops._get_shape_invariant, loop_vars,
+        list(shape_invariants), expand_composites=False)
   else:
-    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)
-
+    shape_invariants = nest.map_structure(
+        control_flow_ops._get_shape_invariant, loop_vars,
+        expand_composites=False)
   if not name:
     name = "while"
 
@@ -150,11 +156,12 @@ def while_loop(cond,
       # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
       # and packs it into the structure of `orig_loop_vars`.
       outputs = body(*_pack_sequence_as(orig_loop_vars, args))
-      if not nest.is_sequence(outputs):
+      if not nest.is_sequence_or_composite(outputs):
         outputs = [outputs]
       # Compare the structure of input and output of body converting the
       # top-level tuples to list to be compatible with legacy while_loop.
-      nest.assert_same_structure(list(outputs), list(orig_loop_vars))
+      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
+                                 expand_composites=True)
 
       outputs = _tensor_array_to_flow(outputs)
 
@@ -193,7 +200,8 @@ def while_loop(cond,
     # Make sure that the shapes of the loop outputs are compatible with the
     # shape invariants, or the shapes of the loop vars if the invariants are not
     # specified.
-    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
+    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
+                                             expand_composites=True))
     # First var is loop counter and second var is maximum_iterations.
     first_loop_var_index = 2
     _check_shapes_compat(
@@ -201,10 +209,10 @@ def while_loop(cond,
                            num_flattened_outputs],
         nest.flatten(
             shape_invariants[first_loop_var_index:first_loop_var_index +
-                             len_orig_loop_vars]),
+                             len_orig_loop_vars], expand_composites=True),
         nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
-                               len_orig_loop_vars]))
-    flattened_loop_vars = nest.flatten(loop_vars)
+                               len_orig_loop_vars], expand_composites=True))
+    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
     _check_num_inputs_outputs(cond_graph, body_graph,
                               len(flattened_loop_vars))
 
@@ -237,7 +245,7 @@ def while_loop(cond,
   if return_same_structure:
     return outputs
 
-  flattened_outputs = nest.flatten(outputs)
+  flattened_outputs = nest.flatten(outputs, expand_composites=True)
   if len(flattened_outputs) == 1:
     return flattened_outputs[0]
   else:
@@ -905,9 +913,11 @@ def _pack_sequence_as(structure_with_tas, loop_vars):
 
   flattened_loop_vars = [
       flow_to_tensor_array(*z)
-      for z in zip(nest.flatten(loop_vars), nest.flatten(structure_with_tas))
+      for z in zip(nest.flatten(loop_vars, expand_composites=True),
+                   nest.flatten(structure_with_tas, expand_composites=True))
   ]
-  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars)
+  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
+                               expand_composites=True)
 
 
 def _tensor_array_to_flow(loop_vars):
@@ -917,14 +927,15 @@ def _tensor_array_to_flow(loop_vars):
       return maybe_ta.flow
     return maybe_ta
 
-  return nest.map_structure(f, loop_vars)
+  return nest.map_structure(f, loop_vars, expand_composites=True)
 
 
 def _build_signature(loop_vars, shape_invariants):
   return nest.pack_sequence_as(loop_vars, [
       tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
-      for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars))
-  ])
+      for s, t in zip(nest.flatten(shape_invariants, expand_composites=True),
+                      nest.flatten(loop_vars, expand_composites=True))
+  ], expand_composites=True)
 
 
 def _build_maximum_iterations_loop_var(maximum_iterations):