Support batchable CompositeTensors as inputs to vectorized_map.

PiperOrigin-RevId: 348498892
Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70
Dave Moore 2020-12-21 11:28:31 -08:00 committed by TensorFlower Gardener
parent b8a634fd5a
commit f0db26599f
2 changed files with 49 additions and 11 deletions
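With this change, `tf.vectorized_map` (the public alias of `pfor_control_flow_ops.vectorized_map`, per the `@tf_export` below) accepts batchable CompositeTensors such as `tf.RaggedTensor` directly as `elems`. A minimal usage sketch, mirroring the new ragged test added in this commit (the values and shapes are illustrative only):

    import tensorflow as tf

    # A uniformly partitioned RaggedTensor of overall shape [2, 2, 3]: it is
    # not *actually* ragged, so every slice along dimension 0 has the same
    # structure and the per-slice results can be stacked.
    rt = tf.RaggedTensor.from_uniform_row_length(
        tf.RaggedTensor.from_row_lengths(
            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
            row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)

    # Each invocation of the mapped fn sees one [2, 3] ragged slice.
    result = tf.vectorized_map(lambda x: x.to_tensor(shape=[2, 3]), rt)
    print(result.shape)  # (2, 2, 3)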


@@ -218,7 +218,7 @@ def _should_expand_composite(value):
 # pylint: disable=protected-access
-def _composite_to_tensors(value):
+def _composite_to_tensors(value, is_batched=False):
   """Converts a CompositeTensor into a list of stackable tensors."""
   if _should_expand_composite(value):
     spec = value._type_spec
@@ -227,6 +227,8 @@ def _composite_to_tensors(value):
                        "parallel_for or vectorized_map loop body must provide "
                        "a `BatchableTypeSpec` (saw: {}).".format(
                            value, spec))
+    if is_batched:
+      return spec._to_batched_tensor_list(value)
     return spec._to_tensor_list(value)
   return value
 # pylint: enable=protected-access
@@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
   return result
 
 
+# pylint: disable=protected-access
+def _gather_from_tensor_or_composite(x, i):
+  """Wrapper for gather that handles CompositeTensors."""
+  if _should_expand_composite(x):
+    spec = x._type_spec
+    gathered_tensors = [_broadcasting_gather(t, i)
+                        for t in spec._to_batched_tensor_list(x)]
+    return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
+  return _broadcasting_gather(x, i)
+# pylint: enable=protected-access
+
+
 @tf_export("vectorized_map")
 def vectorized_map(fn, elems, fallback_to_while_loop=True):
   """Parallel map on the list of tensors unpacked from `elems` on dimension 0.
 
   This method works similar to `tf.map_fn` but is optimized to run much faster,
   possibly with a much larger memory footprint. The speedups are obtained by
   vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
   Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
   behind vectorization is to semantically launch all the invocations of `fn` in
   parallel and fuse corresponding operations across all these invocations. This
   fusion is done statically at graph generation time and the generated code is
@@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
   Raises:
     ValueError: If vectorization fails and fallback_to_while_loop is False.
   """
-  def _convert_to_tensor_or_ndarray(x):
-    if isinstance(x, np_arrays.ndarray):
-      return x
-    return ops.convert_to_tensor(x)
-  elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
+  elems = nest.map_structure(ops.convert_to_tensor,
+                             elems,
+                             expand_composites=True)
 
   def loop_fn(i):
-    gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
-                                        elems)
+    gathered_elems = nest.map_structure(
+        lambda x: _gather_from_tensor_or_composite(x, i), elems)
     return fn(gathered_elems)
 
   # Extract batch size from the maximum first dimension of any element.
-  flat_elems = nest.flatten(elems)
+  flat_elems = nest.flatten(
+      nest.map_structure(
+          functools.partial(_composite_to_tensors,
+                            is_batched=True),
+          elems))
 
   def _get_shape(x):
     if isinstance(x, np_arrays.ndarray):
       x = x.data
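A rough sketch of the mechanics introduced above (not part of the commit): a batchable CompositeTensor is decomposed into its *batched* component tensors, whose common leading dimension supplies the batch size; `_gather_from_tensor_or_composite` then gathers index `i` from each component and rebuilds a single element with `spec._unbatch()._from_compatible_tensor_list`. The snippet below exercises only the decomposition step and relies on the private `BatchableTypeSpec` method named in this diff (plus the public `tf.type_spec_from_value`), so it may break across TensorFlow versions:

    import tensorflow as tf

    rt = tf.RaggedTensor.from_uniform_row_length(
        tf.RaggedTensor.from_row_lengths(tf.range(12), row_lengths=[3, 3, 3, 3]),
        uniform_row_length=2)                      # overall shape [2, 2, 3]

    spec = tf.type_spec_from_value(rt)             # a batchable RaggedTensorSpec
    batched = spec._to_batched_tensor_list(rt)     # private API, as used above
    # Every component's leading dimension equals the batch size (here, 2),
    # which is how the new batch-size extraction in vectorized_map works.
    print([int(t.shape[0]) for t in batched])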


@@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
 from tensorflow.python.ops.parallel_for.test_util import PForTestCase
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.ops.signal import fft_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import nest
@@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
     self.assertTrue(particles.mass.shape, [4, 1, 3])
     self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
 
+  def test_vectorized_map_gathers_composite_tensors(self):
+    particles = Particle(mass=[1., 2., 3., 4., 5.],
+                         velocity=[1., 2., 3., 4., 5.])
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.mass * x.velocity, particles),
+        particles.mass * particles.velocity)
+
+  def test_vectorized_map_of_ragged_tensors(self):
+    # Vmap should be able to handle ragged Tensors as long as they're not
+    # *actually* ragged.
+    ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
+        ragged_tensor.RaggedTensor.from_row_lengths(
+            values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+            row_lengths=[3, 3, 3, 3]),
+        uniform_row_length=2)  # Overall shape [2, 2, 3].
+    self.assertAllEqual(
+        pfor_control_flow_ops.vectorized_map(
+            lambda x: x.to_tensor(shape=[2, 3]), ragged),
+        ragged.to_tensor(shape=[2, 2, 3]))
+
 
 class ParsingTest(PForTestCase):