From f0db26599f0fe687c325b5b7f8d670aee681a6ab Mon Sep 17 00:00:00 2001
From: Dave Moore <davmre@google.com>
Date: Mon, 21 Dec 2020 11:28:31 -0800
Subject: [PATCH] Support batchable CompositeTensors as inputs to
 `vectorized_map`.

PiperOrigin-RevId: 348498892
Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70
---
 .../ops/parallel_for/control_flow_ops.py      | 38 +++++++++++++------
 .../ops/parallel_for/control_flow_ops_test.py | 22 +++++++++++
 2 files changed, 49 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py
index 3ab99636acb..169eb17cda1 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py
@@ -218,7 +218,7 @@ def _should_expand_composite(value):
 
 
 # pylint: disable=protected-access
-def _composite_to_tensors(value):
+def _composite_to_tensors(value, is_batched=False):
   """Converts a CompositeTensor into a list of stackable tensors."""
   if _should_expand_composite(value):
     spec = value._type_spec
@@ -227,6 +227,8 @@ def _composite_to_tensors(value):
                        "parallel_for or vectorized_map loop body must provide "
                        "a `BatchableTypeSpec` (saw: {}).".format(
                            value, spec))
+    if is_batched:
+      return spec._to_batched_tensor_list(value)
     return spec._to_tensor_list(value)
   return value
 # pylint: enable=protected-access
@@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
   return result
 
 
+# pylint: disable=protected-access
+def _gather_from_tensor_or_composite(x, i):
+  """Wrapper for gather that handles CompositeTensors."""
+  if _should_expand_composite(x):
+    spec = x._type_spec
+    gathered_tensors = [_broadcasting_gather(t, i)
+                        for t in spec._to_batched_tensor_list(x)]
+    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
+  return _broadcasting_gather(x, i)
+# pylint: enable=protected-access
+
+
 @tf_export("vectorized_map")
 def vectorized_map(fn, elems, fallback_to_while_loop=True):
   """Parallel map on the list of tensors unpacked from `elems` on dimension 0.
 
   This method works similar to `tf.map_fn` but is optimized to run much faster,
   possibly with a much larger memory footprint. The speedups are obtained by
-  vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, 
-  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea 
+  vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
+  Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
   behind vectorization is to semantically launch all the invocations of `fn` in
   parallel and fuse corresponding operations across all these invocations. This
   fusion is done statically at graph generation time and the generated code is
@@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
   Raises:
     ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
-  def _convert_to_tensor_or_ndarray(x):
-    if isinstance(x, np_arrays.ndarray):
-      return x
-    return ops.convert_to_tensor(x)
-  elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
+  elems = nest.map_structure(ops.convert_to_tensor,
+                             elems,
+                             expand_composites=True)
 
   def loop_fn(i):
-    gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
-                                        elems)
+    gathered_elems = nest.map_structure(
+        lambda x: _gather_from_tensor_or_composite(x, i), elems)
     return fn(gathered_elems)
 
   # Extract batch size from the maximum first dimension of any element.
-  flat_elems = nest.flatten(elems)
+  flat_elems = nest.flatten(
+      nest.map_structure(
+          functools.partial(_composite_to_tensors,
+                            is_batched=True),
+          elems))
   def _get_shape(x):
     if isinstance(x, np_arrays.ndarray):
       x = x.data
diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
index f10d07f37c3..f27f952bb7f 100644
--- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
+++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py
@@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
 from tensorflow.python.ops.parallel_for.test_util import PForTestCase
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
     self.assertTrue(particles.mass.shape, [4, 1, 3])
     self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
 
+  def test_vectorized_map_gathers_composite_tensors(self):
+    particles = Particle(mass=[1., 2., 3., 4., 5.],
+                         velocity=[1., 2., 3., 4., 5.])
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.mass * x.velocity, particles),
+        particles.mass * particles.velocity)
+
+  def test_vectorized_map_of_ragged_tensors(self):
+    # Vmap should be able to handle ragged Tensors as long as they're not
+    # *actually* ragged.
+    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
+        ragged_tensor.RaggedTensor.from_row_lengths(
+            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            row_lengths=[3, 3, 3, 3]),
+        uniform_row_length=2)  # Overall shape [2, 2, 3].
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.to_tensor(shape=[2, 3]), ragged),
+        ragged.to_tensor(shape=[2, 2, 3]))
+
 
 class ParsingTest(PForTestCase):