Support batchable CompositeTensors as inputs to vectorized_map.

PiperOrigin-RevId: 348498892
Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70
Dave Moore 2020-12-21 11:28:31 -08:00 committed by TensorFlower Gardener
parent b8a634fd5a
commit f0db26599f
2 changed files with 49 additions and 11 deletions
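For context, a minimal usage sketch of what this change enables, assuming a TensorFlow build that includes this commit; it mirrors the ragged-tensor test added below:

import tensorflow as tf

# Usage sketch: a batchable CompositeTensor (here a RaggedTensor whose rows
# happen to be uniform) can now be passed directly as `elems` to
# tf.vectorized_map.
ragged = tf.RaggedTensor.from_uniform_row_length(
    tf.RaggedTensor.from_row_lengths(
        values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
        row_lengths=[3, 3, 3, 3]),
    uniform_row_length=2)  # Overall shape [2, 2, 3].

# Each invocation of `fn` receives one batch element (shape [2, 3] here).
dense = tf.vectorized_map(lambda x: x.to_tensor(shape=[2, 3]), ragged)
print(dense.shape)  # (2, 2, 3)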


@@ -218,7 +218,7 @@ def _should_expand_composite(value):
 # pylint: disable=protected-access
-def _composite_to_tensors(value):
+def _composite_to_tensors(value, is_batched=False):
   """Converts a CompositeTensor into a list of stackable tensors."""
   if _should_expand_composite(value):
     spec = value._type_spec
@@ -227,6 +227,8 @@ def _composite_to_tensors(value):
                        "parallel_for or vectorized_map loop body must provide "
                        "a `BatchableTypeSpec` (saw: {}).".format(
                            value, spec))
+    if is_batched:
+      return spec._to_batched_tensor_list(value)
     return spec._to_tensor_list(value)
   return value
 # pylint: enable=protected-access
@@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
   return result


+# pylint: disable=protected-access
+def _gather_from_tensor_or_composite(x, i):
+  """Wrapper for gather that handles CompositeTensors."""
+  if _should_expand_composite(x):
+    spec = x._type_spec
+    gathered_tensors = [_broadcasting_gather(t, i)
+                        for t in spec._to_batched_tensor_list(x)]
+    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
+  return _broadcasting_gather(x, i)
+# pylint: enable=protected-access
+
+
 @tf_export("vectorized_map")
 def vectorized_map(fn, elems, fallback_to_while_loop=True):
   """Parallel map on the list of tensors unpacked from `elems` on dimension 0.

   This method works similar to `tf.map_fn` but is optimized to run much faster,
   possibly with a much larger memory footprint. The speedups are obtained by
   vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
   Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
   behind vectorization is to semantically launch all the invocations of `fn` in
   parallel and fuse corresponding operations across all these invocations. This
   fusion is done statically at graph generation time and the generated code is
@@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
   Raises:
     ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
-  def _convert_to_tensor_or_ndarray(x):
-    if isinstance(x, np_arrays.ndarray):
-      return x
-    return ops.convert_to_tensor(x)
-  elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
+  elems = nest.map_structure(ops.convert_to_tensor,
+                             elems,
+                             expand_composites=True)

   def loop_fn(i):
-    gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
-                                        elems)
+    gathered_elems = nest.map_structure(
+        lambda x: _gather_from_tensor_or_composite(x, i), elems)
     return fn(gathered_elems)

   # Extract batch size from the maximum first dimension of any element.
-  flat_elems = nest.flatten(elems)
+  flat_elems = nest.flatten(
+      nest.map_structure(
+          functools.partial(_composite_to_tensors,
+                            is_batched=True),
+          elems))
   def _get_shape(x):
     if isinstance(x, np_arrays.ndarray):
       x = x.data


@@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
 from tensorflow.python.ops.parallel_for.test_util import PForTestCase
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
     self.assertTrue(particles.mass.shape, [4, 1, 3])
     self.assertAllEqual(particles.velocity.shape, [4, 5, 3])

+  def test_vectorized_map_gathers_composite_tensors(self):
+    particles = Particle(mass=[1., 2., 3., 4., 5.],
+                         velocity=[1., 2., 3., 4., 5.])
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.mass * x.velocity, particles),
+        particles.mass * particles.velocity)
+
+  def test_vectorized_map_of_ragged_tensors(self):
+    # Vmap should be able to handle ragged Tensors as long as they're not
+    # *actually* ragged.
+    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
+        ragged_tensor.RaggedTensor.from_row_lengths(
+            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            row_lengths=[3, 3, 3, 3]),
+        uniform_row_length=2)  # Overall shape [2, 2, 3].
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.to_tensor(shape=[2, 3]), ragged),
+        ragged.to_tensor(shape=[2, 2, 3]))
+

 class ParsingTest(PForTestCase):
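For comparison, a brief sketch that is not part of the change: vectorized_map already supported plain nests (dicts, tuples) of tensors, and the new code paths extend the same per-element gather to batchable composites like the Particle composite used in the test above.

import tensorflow as tf

# Sketch for comparison: a plain dict of tensors, which vectorized_map
# handled before this commit; each invocation of `fn` sees one element of
# every entry in the nest.
particles = {"mass": tf.constant([1., 2., 3., 4., 5.]),
             "velocity": tf.constant([1., 2., 3., 4., 5.])}
momentum = tf.vectorized_map(lambda p: p["mass"] * p["velocity"], particles)
print(momentum.numpy())  # [ 1.  4.  9. 16. 25.]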