Support batchable CompositeTensors as inputs to vectorized_map
.
PiperOrigin-RevId: 348498892 Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70
This commit is contained in:
parent
b8a634fd5a
commit
f0db26599f
@ -218,7 +218,7 @@ def _should_expand_composite(value):
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def _composite_to_tensors(value):
|
||||
def _composite_to_tensors(value, is_batched=False):
|
||||
"""Converts a CompositeTensor into a list of stackable tensors."""
|
||||
if _should_expand_composite(value):
|
||||
spec = value._type_spec
|
||||
@ -227,6 +227,8 @@ def _composite_to_tensors(value):
|
||||
"parallel_for or vectorized_map loop body must provide "
|
||||
"a `BatchableTypeSpec` (saw: {}).".format(
|
||||
value, spec))
|
||||
if is_batched:
|
||||
return spec._to_batched_tensor_list(value)
|
||||
return spec._to_tensor_list(value)
|
||||
return value
|
||||
# pylint: enable=protected-access
|
||||
@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
|
||||
return result
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def _gather_from_tensor_or_composite(x, i):
|
||||
"""Wrapper for gather that handles CompositeTensors."""
|
||||
if _should_expand_composite(x):
|
||||
spec = x._type_spec
|
||||
gathered_tensors = [_broadcasting_gather(t, i)
|
||||
for t in spec._to_batched_tensor_list(x)]
|
||||
return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
|
||||
return _broadcasting_gather(x, i)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
@tf_export("vectorized_map")
|
||||
def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
||||
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This method works similar to `tf.map_fn` but is optimized to run much faster,
|
||||
possibly with a much larger memory footprint. The speedups are obtained by
|
||||
vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
|
||||
Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
|
||||
vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
|
||||
Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
|
||||
behind vectorization is to semantically launch all the invocations of `fn` in
|
||||
parallel and fuse corresponding operations across all these invocations. This
|
||||
fusion is done statically at graph generation time and the generated code is
|
||||
@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
||||
Raises:
|
||||
ValueError: If vectorization fails and fallback_to_while_loop is False.
|
||||
"""
|
||||
def _convert_to_tensor_or_ndarray(x):
|
||||
if isinstance(x, np_arrays.ndarray):
|
||||
return x
|
||||
return ops.convert_to_tensor(x)
|
||||
elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
|
||||
elems = nest.map_structure(ops.convert_to_tensor,
|
||||
elems,
|
||||
expand_composites=True)
|
||||
|
||||
def loop_fn(i):
|
||||
gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
|
||||
elems)
|
||||
gathered_elems = nest.map_structure(
|
||||
lambda x: _gather_from_tensor_or_composite(x, i), elems)
|
||||
return fn(gathered_elems)
|
||||
|
||||
# Extract batch size from the maximum first dimension of any element.
|
||||
flat_elems = nest.flatten(elems)
|
||||
flat_elems = nest.flatten(
|
||||
nest.map_structure(
|
||||
functools.partial(_composite_to_tensors,
|
||||
is_batched=True),
|
||||
elems))
|
||||
def _get_shape(x):
|
||||
if isinstance(x, np_arrays.ndarray):
|
||||
x = x.data
|
||||
|
@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
|
||||
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.signal import fft_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
|
||||
self.assertTrue(particles.mass.shape, [4, 1, 3])
|
||||
self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
|
||||
|
||||
def test_vectorized_map_gathers_composite_tensors(self):
|
||||
particles = Particle(mass=[1., 2., 3., 4., 5.],
|
||||
velocity=[1., 2., 3., 4., 5.])
|
||||
self.assertAllEqual(
|
||||
pfor_control_flow_ops.vectorized_map(
|
||||
lambda x: x.mass * x.velocity, particles),
|
||||
particles.mass * particles.velocity)
|
||||
|
||||
def test_vectorized_map_of_ragged_tensors(self):
|
||||
# Vmap should be able to handle ragged Tensors as long as they're not
|
||||
# *actually* ragged.
|
||||
ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||
ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
row_lengths=[3, 3, 3, 3]),
|
||||
uniform_row_length=2) # Overall shape [2, 2, 3].
|
||||
self.assertAllEqual(
|
||||
pfor_control_flow_ops.vectorized_map(
|
||||
lambda x: x.to_tensor(shape=[2, 3]), ragged),
|
||||
ragged.to_tensor(shape=[2, 2, 3]))
|
||||
|
||||
|
||||
class ParsingTest(PForTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user