From f0db26599f0fe687c325b5b7f8d670aee681a6ab Mon Sep 17 00:00:00 2001 From: Dave Moore <davmre@google.com> Date: Mon, 21 Dec 2020 11:28:31 -0800 Subject: [PATCH] Support batchable CompositeTensors as inputs to `vectorized_map`. PiperOrigin-RevId: 348498892 Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70 --- .../ops/parallel_for/control_flow_ops.py | 38 +++++++++++++------ .../ops/parallel_for/control_flow_ops_test.py | 22 +++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index 3ab99636acb..169eb17cda1 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -218,7 +218,7 @@ def _should_expand_composite(value): # pylint: disable=protected-access -def _composite_to_tensors(value): +def _composite_to_tensors(value, is_batched=False): """Converts a CompositeTensor into a list of stackable tensors.""" if _should_expand_composite(value): spec = value._type_spec @@ -227,6 +227,8 @@ def _composite_to_tensors(value): "parallel_for or vectorized_map loop body must provide " "a `BatchableTypeSpec` (saw: {}).".format( value, spec)) + if is_batched: + return spec._to_batched_tensor_list(value) return spec._to_tensor_list(value) return value # pylint: enable=protected-access @@ -421,14 +423,26 @@ def _broadcasting_gather(x, i): return result +# pylint: disable=protected-access +def _gather_from_tensor_or_composite(x, i): + """Wrapper for gather that handles CompositeTensors.""" + if _should_expand_composite(x): + spec = x._type_spec + gathered_tensors = [_broadcasting_gather(t, i) + for t in spec._to_batched_tensor_list(x)] + return spec._unbatch()._from_compatible_tensor_list(gathered_tensors) + return _broadcasting_gather(x, i) +# pylint: enable=protected-access + + @tf_export("vectorized_map") def vectorized_map(fn, elems, fallback_to_while_loop=True): """Parallel map on the list of tensors unpacked from `elems` on dimension 0. This method works similar to `tf.map_fn` but is optimized to run much faster, possibly with a much larger memory footprint. The speedups are obtained by - vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, - Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea + vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, + Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea behind vectorization is to semantically launch all the invocations of `fn` in parallel and fuse corresponding operations across all these invocations. This fusion is done statically at graph generation time and the generated code is @@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True): Raises: ValueError: If vectorization fails and fallback_to_while_loop is False. """ - def _convert_to_tensor_or_ndarray(x): - if isinstance(x, np_arrays.ndarray): - return x - return ops.convert_to_tensor(x) - elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems) + elems = nest.map_structure(ops.convert_to_tensor, + elems, + expand_composites=True) def loop_fn(i): - gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i), - elems) + gathered_elems = nest.map_structure( + lambda x: _gather_from_tensor_or_composite(x, i), elems) return fn(gathered_elems) # Extract batch size from the maximum first dimension of any element. - flat_elems = nest.flatten(elems) + flat_elems = nest.flatten( + nest.map_structure( + functools.partial(_composite_to_tensors, + is_batched=True), + elems)) def _get_shape(x): if isinstance(x, np_arrays.ndarray): x = x.data diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py index f10d07f37c3..f27f952bb7f 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops_test.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops_test.py @@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops from tensorflow.python.ops.parallel_for.test_util import PForTestCase +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.signal import fft_ops from tensorflow.python.platform import test from tensorflow.python.util import nest @@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase): self.assertTrue(particles.mass.shape, [4, 1, 3]) self.assertAllEqual(particles.velocity.shape, [4, 5, 3]) + def test_vectorized_map_gathers_composite_tensors(self): + particles = Particle(mass=[1., 2., 3., 4., 5.], + velocity=[1., 2., 3., 4., 5.]) + self.assertAllEqual( + pfor_control_flow_ops.vectorized_map( + lambda x: x.mass * x.velocity, particles), + particles.mass * particles.velocity) + + def test_vectorized_map_of_ragged_tensors(self): + # Vmap should be able to handle ragged Tensors as long as they're not + # *actually* ragged. + ragged = ragged_tensor.RaggedTensor.from_uniform_row_length( + ragged_tensor.RaggedTensor.from_row_lengths( + values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + row_lengths=[3, 3, 3, 3]), + uniform_row_length=2) # Overall shape [2, 2, 3]. + self.assertAllEqual( + pfor_control_flow_ops.vectorized_map( + lambda x: x.to_tensor(shape=[2, 3]), ragged), + ragged.to_tensor(shape=[2, 2, 3])) + class ParsingTest(PForTestCase):