Support batchable CompositeTensors as inputs to vectorized_map.
PiperOrigin-RevId: 348498892 Change-Id: I82ac5012dbf13705af584fb0350f14d170a3ff70
This commit is contained in:
parent
b8a634fd5a
commit
f0db26599f
@ -218,7 +218,7 @@ def _should_expand_composite(value):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def _composite_to_tensors(value):
|
def _composite_to_tensors(value, is_batched=False):
|
||||||
"""Converts a CompositeTensor into a list of stackable tensors."""
|
"""Converts a CompositeTensor into a list of stackable tensors."""
|
||||||
if _should_expand_composite(value):
|
if _should_expand_composite(value):
|
||||||
spec = value._type_spec
|
spec = value._type_spec
|
||||||
@ -227,6 +227,8 @@ def _composite_to_tensors(value):
|
|||||||
"parallel_for or vectorized_map loop body must provide "
|
"parallel_for or vectorized_map loop body must provide "
|
||||||
"a `BatchableTypeSpec` (saw: {}).".format(
|
"a `BatchableTypeSpec` (saw: {}).".format(
|
||||||
value, spec))
|
value, spec))
|
||||||
|
if is_batched:
|
||||||
|
return spec._to_batched_tensor_list(value)
|
||||||
return spec._to_tensor_list(value)
|
return spec._to_tensor_list(value)
|
||||||
return value
|
return value
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
@ -421,14 +423,26 @@ def _broadcasting_gather(x, i):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
def _gather_from_tensor_or_composite(x, i):
|
||||||
|
"""Wrapper for gather that handles CompositeTensors."""
|
||||||
|
if _should_expand_composite(x):
|
||||||
|
spec = x._type_spec
|
||||||
|
gathered_tensors = [_broadcasting_gather(t, i)
|
||||||
|
for t in spec._to_batched_tensor_list(x)]
|
||||||
|
return spec._unbatch()._from_compatible_tensor_list(gathered_tensors)
|
||||||
|
return _broadcasting_gather(x, i)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
@tf_export("vectorized_map")
|
@tf_export("vectorized_map")
|
||||||
def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
||||||
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
|
"""Parallel map on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
This method works similar to `tf.map_fn` but is optimized to run much faster,
|
This method works similar to `tf.map_fn` but is optimized to run much faster,
|
||||||
possibly with a much larger memory footprint. The speedups are obtained by
|
possibly with a much larger memory footprint. The speedups are obtained by
|
||||||
vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
|
vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians,
|
||||||
Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
|
Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea
|
||||||
behind vectorization is to semantically launch all the invocations of `fn` in
|
behind vectorization is to semantically launch all the invocations of `fn` in
|
||||||
parallel and fuse corresponding operations across all these invocations. This
|
parallel and fuse corresponding operations across all these invocations. This
|
||||||
fusion is done statically at graph generation time and the generated code is
|
fusion is done statically at graph generation time and the generated code is
|
||||||
@ -518,19 +532,21 @@ def vectorized_map(fn, elems, fallback_to_while_loop=True):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If vectorization fails and fallback_to_while_loop is False.
|
ValueError: If vectorization fails and fallback_to_while_loop is False.
|
||||||
"""
|
"""
|
||||||
def _convert_to_tensor_or_ndarray(x):
|
elems = nest.map_structure(ops.convert_to_tensor,
|
||||||
if isinstance(x, np_arrays.ndarray):
|
elems,
|
||||||
return x
|
expand_composites=True)
|
||||||
return ops.convert_to_tensor(x)
|
|
||||||
elems = nest.map_structure(_convert_to_tensor_or_ndarray, elems)
|
|
||||||
|
|
||||||
def loop_fn(i):
|
def loop_fn(i):
|
||||||
gathered_elems = nest.map_structure(lambda x: _broadcasting_gather(x, i),
|
gathered_elems = nest.map_structure(
|
||||||
elems)
|
lambda x: _gather_from_tensor_or_composite(x, i), elems)
|
||||||
return fn(gathered_elems)
|
return fn(gathered_elems)
|
||||||
|
|
||||||
# Extract batch size from the maximum first dimension of any element.
|
# Extract batch size from the maximum first dimension of any element.
|
||||||
flat_elems = nest.flatten(elems)
|
flat_elems = nest.flatten(
|
||||||
|
nest.map_structure(
|
||||||
|
functools.partial(_composite_to_tensors,
|
||||||
|
is_batched=True),
|
||||||
|
elems))
|
||||||
def _get_shape(x):
|
def _get_shape(x):
|
||||||
if isinstance(x, np_arrays.ndarray):
|
if isinstance(x, np_arrays.ndarray):
|
||||||
x = x.data
|
x = x.data
|
||||||
|
@ -70,6 +70,7 @@ from tensorflow.python.ops import tensor_array_ops
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
|
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
|
||||||
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.ops.signal import fft_ops
|
from tensorflow.python.ops.signal import fft_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -2157,6 +2158,27 @@ class CompositeTensorTest(PForTestCase, parameterized.TestCase):
|
|||||||
self.assertTrue(particles.mass.shape, [4, 1, 3])
|
self.assertTrue(particles.mass.shape, [4, 1, 3])
|
||||||
self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
|
self.assertAllEqual(particles.velocity.shape, [4, 5, 3])
|
||||||
|
|
||||||
|
def test_vectorized_map_gathers_composite_tensors(self):
|
||||||
|
particles = Particle(mass=[1., 2., 3., 4., 5.],
|
||||||
|
velocity=[1., 2., 3., 4., 5.])
|
||||||
|
self.assertAllEqual(
|
||||||
|
pfor_control_flow_ops.vectorized_map(
|
||||||
|
lambda x: x.mass * x.velocity, particles),
|
||||||
|
particles.mass * particles.velocity)
|
||||||
|
|
||||||
|
def test_vectorized_map_of_ragged_tensors(self):
|
||||||
|
# Vmap should be able to handle ragged Tensors as long as they're not
|
||||||
|
# *actually* ragged.
|
||||||
|
ragged = ragged_tensor.RaggedTensor.from_uniform_row_length(
|
||||||
|
ragged_tensor.RaggedTensor.from_row_lengths(
|
||||||
|
values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||||
|
row_lengths=[3, 3, 3, 3]),
|
||||||
|
uniform_row_length=2) # Overall shape [2, 2, 3].
|
||||||
|
self.assertAllEqual(
|
||||||
|
pfor_control_flow_ops.vectorized_map(
|
||||||
|
lambda x: x.to_tensor(shape=[2, 3]), ragged),
|
||||||
|
ragged.to_tensor(shape=[2, 2, 3]))
|
||||||
|
|
||||||
|
|
||||||
class ParsingTest(PForTestCase):
|
class ParsingTest(PForTestCase):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user