Update tf.map_fn to support RaggedTensors and SparseTensors.

PiperOrigin-RevId: 300134914
Change-Id: I84e75ea1d71257d34d045b82b54de208ea2fe75b
Edward Loper 2020-03-10 11:16:21 -07:00 committed by TensorFlower Gardener
parent 3299b778c3
commit 0bbfe8d95d
7 changed files with 521 additions and 578 deletions

View File

@@ -69,13 +69,14 @@ class MapFnTest(test.TestCase):
   def testMapSparseTensor(self):
     with self.cached_session():
-      with self.assertRaises(TypeError):
-        map_fn.map_fn(
-            lambda x: x,
-            sparse_tensor.SparseTensor(
-                indices=[[0, 0], [0, 1], [1, 0]],
-                values=constant_op.constant([0, 1, 2]),
-                dense_shape=[2, 2]))
+      st = sparse_tensor.SparseTensor(
+          indices=[[0, 0], [0, 1], [1, 0]],
+          values=constant_op.constant([0, 1, 2]),
+          dense_shape=[2, 2])
+      result = map_fn.map_fn(lambda x: x, st)
+      self.assertAllEqual(result.indices, st.indices)
+      self.assertAllEqual(result.values, st.values)
+      self.assertAllEqual(result.dense_shape, st.dense_shape)

   @test_util.run_in_graph_and_eager_modes
   def testMapOverScalarErrors(self):
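
The updated test exercises the headline change: mapping over a `tf.SparseTensor` no longer raises `TypeError`. A minimal sketch of the new behavior (assuming a TensorFlow build that includes this change; `fn` receives each row as a `SparseTensor` with one fewer dimension):

```python
import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                     values=tf.constant([0, 1, 2]),
                     dense_shape=[2, 2])
# The identity map passes each row to `fn` as a 1-D SparseTensor; the per-row
# results are re-stacked into a SparseTensor with the original dense_shape.
result = tf.map_fn(lambda x: x, st)
print(tf.sparse.to_dense(result).numpy())  # [[0 1]
                                           #  [2 0]]
```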

View File

@@ -20,15 +20,20 @@ from __future__ import division
 from __future__ import print_function

+import re
+
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import type_spec
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
@@ -36,134 +41,324 @@ from tensorflow.python.util.tf_export import tf_export

 @tf_export(v1=["map_fn"])
-def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
-           swap_memory=False, infer_shape=True, name=None):
-  """map on the list of tensors unpacked from `elems` on dimension 0.
-
-  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
-  sequence of elements from first to last. The elements are made of the
-  tensors unpacked from `elems`. `dtype` is the data type of the return
-  value of `fn`. Users must provide `dtype` if it is different from
-  the data type of `elems`.
-
-  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
-  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.
-
-  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
-  is a (possibly nested) list or tuple of tensors, then each of these tensors
-  must have a matching first (unpack) dimension.  The signature of `fn` may
-  match the structure of `elems`.  That is, if `elems` is
-  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
-  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
-
-  Furthermore, `fn` may emit a different structure than its input.  For example,
-  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
-  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
-  nested) tuple of types matching the output of `fn`.
-
-  To apply a functional operation to the nonzero elements of a SparseTensor
-  one of the following methods is recommended. First, if the function is
-  expressible as TensorFlow ops, use
-
-  ```python
-    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
-  ```
-
-  If, however, the function is not expressible as a TensorFlow op, then use
-
-  ```python
-  result = SparseTensor(
-    input.indices, map_fn(fn, input.values), input.dense_shape)
-  ```
-
-  instead.
-
-  When executing eagerly, map_fn does not execute in parallel even if
+@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
+def map_fn(fn,
+           elems,
+           dtype=None,
+           parallel_iterations=None,
+           back_prop=True,
+           swap_memory=False,
+           infer_shape=True,
+           name=None,
+           fn_output_signature=None):
+  """Transforms `elems` by applying `fn` to each element unstacked on axis 0.
+
+  `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
+  calls `fn` to transform each element; and then stacks the transformed
+  values back together.
+
+  #### Mapping functions with single-Tensor inputs and outputs
+
+  If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
+  then `map_fn(fn, elems)` is equivalent to
+  `tf.stack([fn(elem) for elem in tf.unstack(elems)])`.  E.g.:
+
+  >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
+  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
+    array([[3, 4, 5],
+           [5, 6, 7],
+           [2, 3, 4]], dtype=int32)>
+
+  `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.
+
+  #### Mapping functions with multi-arity inputs and outputs
+
+  `map_fn` also supports functions with multi-arity inputs and outputs:
+
+  * If `elems` is a tuple (or nested structure) of tensors, then those tensors
+    must all have the same outer-dimension size (`num_elems`); and `fn` is
+    used to transform each tuple (or structure) of corresponding slices from
+    `elems`.  E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
+    transform each tuple of slices `(t1[i], t2[i], t3[i])`
+    (where `0 <= i < num_elems`).
+
+  * If `fn` returns a tuple (or nested structure) of tensors, then the
+    result is formed by stacking corresponding elements from those structures.
+
+  #### Specifying `fn`'s output signature
+
+  If `fn`'s input and output signatures are different, then the output
+  signature must be specified using `fn_output_signature`.  (The input and
+  output signatures differ if their structures, dtypes, or tensor types do
+  not match.)  E.g.:
+
+  >>> tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
+  ...           elems=tf.constant(["hello", "moon"]),
+  ...           fn_output_signature=tf.int32)
+  <tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
+  >>> tf.map_fn(fn=tf.strings.join,  # input & output have different structures
+  ...           elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
+  ...           fn_output_signature=tf.string)
+  <tf.Tensor: shape=(2,), dtype=string,
+   numpy=array([b'TheDog', b'ACat'], dtype=object)>
+
+  `fn_output_signature` can be specified using any of the following:
+
+  * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
+  * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
+  * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
+  * A (possibly nested) tuple, list, or dict containing the above types.
+
+  #### RaggedTensors
+
+  `map_fn` supports `tf.RaggedTensor` inputs and outputs.  In particular:
+
+  * If `elems` is a `RaggedTensor`, then `fn` will be called with each
+    row of that ragged tensor.
+    * If `elems` has only one ragged dimension, then the values passed to
+      `fn` will be `tf.Tensor`s.
+    * If `elems` has multiple ragged dimensions, then the values passed to
+      `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.
+
+  * If the result of `map_fn` should be a `RaggedTensor`, then use a
+    `tf.RaggedTensorSpec` to specify `fn_output_signature`.
+    * If `fn` returns `tf.Tensor`s with varying sizes, then use a
+      `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
+      single ragged tensor (which will have ragged_rank=1).
+    * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
+      with the same `ragged_rank`.
+
+  >>> # Example: RaggedTensor input
+  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
+  >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
+  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
+
+  >>> # Example: RaggedTensor output
+  >>> elems = tf.constant([3, 5, 0, 2])
+  >>> tf.map_fn(tf.range, elems,
+  ...           fn_output_signature=tf.RaggedTensorSpec(shape=[None],
+  ...                                                   dtype=tf.int32))
+  <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
+
+  Note: `map_fn` should only be used if you need to map a function over the
+  *rows* of a `RaggedTensor`.  If you wish to map a function over the
+  individual values, then you should use:
+
+  * `tf.ragged.map_flat_values(fn, rt)`
+    (if fn is expressible as TensorFlow ops)
+  * `rt.with_flat_values(map_fn(fn, rt.flat_values))`
+    (otherwise)
+
+  E.g.:
+
+  >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
+  >>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
+  <tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
+
+  #### SparseTensors
+
+  `map_fn` supports `tf.SparseTensor` inputs and outputs.  In particular:
+
+  * If `elems` is a `SparseTensor`, then `fn` will be called with each row
+    of that sparse tensor. In particular, the value passed to `fn` will be a
+    `tf.SparseTensor` with one fewer dimension than `elems`.
+
+  * If the result of `map_fn` should be a `SparseTensor`, then use a
+    `tf.SparseTensorSpec` to specify `fn_output_signature`.  The individual
+    `SparseTensor`s returned by `fn` will be stacked into a single
+    `SparseTensor` with one more dimension.
+
+  >>> # Example: SparseTensor input
+  >>> st = tf.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
+  >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
+  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
+
+  >>> # Example: SparseTensor output
+  >>> tf.sparse.to_dense(
+  ...     tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
+  ...               fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
+  <tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
+    array([[[1., 0., 0.],
+            [0., 1., 0.],
+            [0., 0., 0.]],
+           [[1., 0., 0.],
+            [0., 1., 0.],
+            [0., 0., 1.]]], dtype=float32)>
+
+  Note: `map_fn` should only be used if you need to map a function over the
+  *rows* of a `SparseTensor`.  If you wish to map a function over the nonzero
+  values, then you should use:
+
+  * `tf.SparseTensor(st.indices, fn(st.values), st.dense_shape)`
+    (if the function is expressible as TensorFlow ops)
+  * `tf.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)`
+    (otherwise).
+
+  #### `map_fn` vs. vectorized operations
+
+  `map_fn` will apply the operations used by `fn` to each element of `elems`,
+  resulting in `O(elems.shape[0])` total operations.  This is somewhat
+  mitigated by the fact that `map_fn` can process elements in parallel.
+  However, a transform expressed using `map_fn` is still typically less
+  efficient than an equivalent transform expressed using vectorized operations.
+
+  `map_fn` should typically only be used if one of the following is true:
+
+  * It is difficult or expensive to express the desired transform with
+    vectorized operations.
+  * `fn` creates large intermediate values, so an equivalent vectorized
+    transform would take too much memory.
+  * Processing elements in parallel is more efficient than an equivalent
+    vectorized transform.
+  * Efficiency of the transform is not critical, and using `map_fn` is
+    more readable.
+
+  E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
+  across `elems` could be rewritten more efficiently using vectorized ops:
+
+  >>> elems = tf.constant([3, 5, 2])
+  >>> tf.range(3) + tf.expand_dims(elems, 1)
+  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
+    array([[3, 4, 5],
+           [5, 6, 7],
+           [2, 3, 4]], dtype=int32)>
+
+  In some cases, `tf.vectorized_map` can be used to automatically convert a
+  function to a vectorized equivalent.
+
+  #### Eager execution
+
+  When executing eagerly, `map_fn` does not execute in parallel even if
   `parallel_iterations` is set to a value > 1. You can still get the
   performance benefits of running a function in parallel by using the
-  `tf.function` decorator,
-
-  ```python
-  # Assume the function being used in map_fn is fn.
-  # To ensure map_fn calls fn in parallel, use the tf.function decorator.
-  @tf.function
-  def func(tensor):
-    return tf.map_fn(fn, tensor)
-  ```
+  `tf.function` decorator:
+
+  >>> fn=lambda t: tf.range(t, t + 3)
+  >>> @tf.function
+  ... def func(elems):
+  ...   return tf.map_fn(fn, elems, parallel_iterations=3)
+  >>> func(tf.constant([3, 5, 2]))
+  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
+    array([[3, 4, 5],
+           [5, 6, 7],
+           [2, 3, 4]], dtype=int32)>

   Note that if you use the `tf.function` decorator, any non-TensorFlow Python
   code that you may have written in your function won't get executed. See
-  [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for
-  more details. The recommendation would be to debug without `tf.function` but
-  switch to it to get performance benefits of running `map_fn` in parallel.
+  `tf.function` for more details. The recommendation would be to debug without
+  `tf.function` but switch to it to get performance benefits of running `map_fn`
+  in parallel.

   Args:
-    fn: The callable to be performed.  It accepts one argument, which will
-      have the same (possibly nested) structure as `elems`.  Its output
-      must have the same structure as `dtype` if one is provided, otherwise
-      it must have the same structure as `elems`.
-    elems: A tensor or (possibly nested) sequence of tensors, each of which
-      will be unpacked along their first dimension.  The nested sequence
-      of the resulting slices will be applied to `fn`.
-    dtype: (optional) The output type(s) of `fn`.  If `fn` returns a structure
-      of Tensors differing from the structure of `elems`, then `dtype` is not
-      optional and must have the same structure as the output of `fn`.
-    parallel_iterations: (optional) The number of iterations allowed to run
-      in parallel. When graph building, the default value is 10. While executing
+    fn: The callable to be performed.  It accepts one argument, which will have
+      the same (possibly nested) structure as `elems`.  Its output must have the
+      same structure as `fn_output_signature` if one is provided; otherwise it
+      must have the same structure as `elems`.
+    elems: A tensor or (possibly nested) sequence of tensors, each of which will
+      be unstacked along their first dimension.  `fn` will be applied to the
+      nested sequence of the resulting slices.  `elems` may include ragged and
+      sparse tensors.
+    dtype: Deprecated: Equivalent to `fn_output_signature`.
+    parallel_iterations: (optional) The number of iterations allowed to run in
+      parallel. When graph building, the default value is 10. While executing
       eagerly, the default value is set to 1.
-    back_prop: (optional) True enables support for back propagation.
+    back_prop: (optional) False disables support for back propagation.
     swap_memory: (optional) True enables GPU-CPU memory swapping.
     infer_shape: (optional) False disables tests for consistent output shapes.
     name: (optional) Name prefix for the returned tensors.
+    fn_output_signature: The output signature of `fn`. Must be specified if
+      `fn`'s input and output signatures are different (i.e., if their
+      structures, dtypes, or tensor types do not match).
+      `fn_output_signature` can be specified using any of the following:
+
+      * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
+      * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
+      * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
+      * A (possibly nested) tuple, list, or dict containing the above types.

   Returns:
-    A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
-    results of applying `fn` to tensors unpacked from `elems` along the first
-    dimension, from first to last.
+    A tensor or (possibly nested) sequence of tensors.  Each tensor stacks the
+    results of applying `fn` to tensors unstacked from `elems` along the first
+    dimension, from first to last.  The result may include ragged and sparse
+    tensors.

   Raises:
     TypeError: if `fn` is not callable or the structure of the output of
-      `fn` and `dtype` do not match, or if elems is a SparseTensor.
-    ValueError: if the lengths of the output of `fn` and `dtype` do not match.
+      `fn` and `fn_output_signature` do not match.
+    ValueError: if the lengths of the output of `fn` and `fn_output_signature`
+      do not match.

   Examples:
-    ```python
-    elems = np.array([1, 2, 3, 4, 5, 6])
-    squares = map_fn(lambda x: x * x, elems)
-    # squares == [1, 4, 9, 16, 25, 36]
-    ```
-
-    ```python
-    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
-    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
-    # alternate == [-1, 2, -3]
-    ```
-
-    ```python
-    elems = np.array([1, 2, 3])
-    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
-    # alternates[0] == [1, 2, 3]
-    # alternates[1] == [-1, -2, -3]
-    ```
+
+    >>> elems = np.array([1, 2, 3, 4, 5, 6])
+    >>> tf.map_fn(lambda x: x * x, elems)
+    <tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1,  4,  9, 16, 25, 36])>
+
+    >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
+    >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
+    <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1,  2, -3])>
+
+    >>> elems = np.array([1, 2, 3])
+    >>> tf.map_fn(lambda x: (x, -x), elems,
+    ...           fn_output_signature=(tf.int64, tf.int64))
+    (<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
+     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
   """
+  # This function uses a `while_loop` to call `fn` on each value of the input
+  # tensor(s) (unstacked on dimension 0).  The following sequence of variables
+  # are used to transform the input tensor(s) (`elems`) into the output
+  # tensor(s) (`result`):
+  #
+  #   - Preparing and unstacking input values for the while_loop:
+  #     - elems: The input tensor(s) to map_fn. May include composite tensors.
+  #     - elems_flat: Flattened list of tensors from elems (using nest.flatten)
+  #                   May include composite tensors.
+  #     - elems_batchable: Concatenation of "batchable tensor lists" for each
+  #                        tensor in elems_flat.  This "boxes" composite tensors
+  #                        into sliceable tf.Tensor objects.  For more info see:
+  #                        TensorSpec._to_batched_tensor_list
+  #     - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
+  #                           in elems_batchable into elems_value_batchable.
+  #
+  #   - Calling `fn` on each unstacked value in the body of the while_loop:
+  #     - elems_value_batchable: Single unstacked value from elems_batchable.
+  #     - elems_value_flat: Single unstacked value from elems_flat,
+  #                         constructed from elems_value_batchable (using
+  #                         TensorSpec._from_tensor_list).
+  #     - elems_value: Single unstacked value from elems (the input to fn).
+  #     - result_value: Result of calling `fn(elems_value)`.  May contain
+  #                     composite tensors.
+  #     - result_value_flat: Flattened list of tensors from result_value.
+  #                          May contain composite tensors.
+  #     - result_value_batchable: Concatenation of batchable tensor lists for
+  #                               each tensor in result_value_flat
+  #                               (using TensorSpec._to_tensor_list).
+  #
+  #   - Collecting and stacking output values from the while_loop:
+  #     - result_batchable_ta: List of TensorArrays used to stack each tensor
+  #                            in result_value_batchable into result_batchable.
+  #     - result_batchable: Stacked tensors from result_batchable_ta.
+  #     - result_flat: Flat list of tensors for the result, constructed from
+  #                    result_batchable (using TensorSpec._from_tensor_list).
+  #     - result: Structured result value packed from result_flat
+  #               (using nest.pack_sequence_as).
+  if fn_output_signature is None:
+    fn_output_signature = dtype
+
   if not callable(fn):
     raise TypeError("fn must be callable.")

-  if isinstance(elems, sparse_tensor.SparseTensor):
-    raise TypeError(
-        "To perform a map on the values of a sparse tensor use either "
-        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
-        " SparseTensor(input.indices, map_fn(fn, input.values), "
-        "input.dense_shape)")
-
   in_graph_mode = not context.executing_eagerly()
   # Set the default number of parallel_iterations depending on graph/eager mode.
   if in_graph_mode and not parallel_iterations:
     parallel_iterations = 10
   elif not in_graph_mode and not parallel_iterations:
     parallel_iterations = 1
-
-  if not in_graph_mode and parallel_iterations > 1:
+  elif not in_graph_mode and parallel_iterations > 1:
     logging.log_first_n(
         logging.WARN, "Setting parallel_iterations > 1 has no "
         "effect when executing eagerly. Consider calling map_fn"
@@ -171,23 +366,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
         "parallel.", 1)
     parallel_iterations = 1

-  input_is_sequence = nest.is_sequence(elems)
-  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
-
-  def input_pack(x):
-    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
-
-  if dtype is None:
-    output_is_sequence = input_is_sequence
-    output_flatten = input_flatten
-    output_pack = input_pack
+  # Flatten the input tensors, and get the TypeSpec for each one.
+  elems_flat = nest.flatten(elems)
+  elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
+  elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
+
+  # Flatten fn's output signature.
+  if fn_output_signature is None:
+    # If fn_output_signature was not specified, then assume that it matches the
+    # input signature.
+    result_flat_signature = [
+        _most_general_compatible_type(s)._unbatch()  # pylint: disable=protected-access
+        for s in elems_flat_signature
+    ]
+    result_unflatten = elems_unflatten
   else:
-    output_is_sequence = nest.is_sequence(dtype)
-    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
-
-    def output_pack(x):
-      return (nest.pack_sequence_as(dtype, x)
-              if output_is_sequence else x[0])
-
-  elems_flat = input_flatten(elems)
+    result_flat_signature = [
+        _dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
+    ]
+    result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)

   with ops.name_scope(name, "map", elems_flat):
     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
@@ -204,42 +401,56 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
       varscope_caching_device_was_none = True

     elems_flat = [
-        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
-
-    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
-    dtype_flat = output_flatten(dtype)
-
-    # Convert elems to tensor array. n may be known statically.
-    static_shape = elems_flat[0].shape
-    if static_shape.ndims is not None and static_shape.ndims < 1:
+        ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
+    ]
+
+    # Check that inputs are not scalars.
+    elems_static_shape = elems_flat[0].shape
+    if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
       if len(elems_flat) == 1:
         raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
       else:
         raise ValueError(
             "elements in elems must be 1+ dimensional Tensors, not scalars"
         )
-    n = (tensor_shape.dimension_value(static_shape[0])
-         or array_ops.shape(elems_flat[0])[0])
-
-    # TensorArrays are always flat
-    elems_ta = [
-        tensor_array_ops.TensorArray(dtype=elem.dtype,
-                                     size=n,
-                                     dynamic_size=False,
-                                     infer_shape=True)
-        for elem in elems_flat]
+
+    # Box any composite tensors into tensor lists.
+    elems_batchable = _elems_flat_to_batchable(elems_flat)
+
+    # Find the number of iterations, n.  (may be known statically.)
+    n_static = tensor_shape.Dimension(
+        tensor_shape.dimension_value(
+            elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
+    for tensor in elems_batchable[1:]:
+      n_static.merge_with(
+          tensor_shape.Dimension(
+              tensor_shape.dimension_value(
+                  tensor.get_shape().with_rank_at_least(1)[0])))
+    n = n_static.value or array_ops.shape(elems_batchable[0])[0]
+
+    # Convert elems to tensor array.
+    # TODO(edloper): Should we set infer_shape=False for composite tensors?
+    elems_batchable_ta = [
+        tensor_array_ops.TensorArray(
+            dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
+        for t in elems_batchable
+    ]
     # Unpack elements
-    elems_ta = [
-        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]
+    elems_batchable_ta = [
+        ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
+    ]

     i = constant_op.constant(0)

-    accs_ta = [
-        tensor_array_ops.TensorArray(dtype=dt,
-                                     size=n,
-                                     dynamic_size=False,
-                                     infer_shape=infer_shape)
-        for dt in dtype_flat]
+    # Prepare result tensor array.
+    # TODO(edloper): Should we set infer_shape=False for composite tensors?
+    result_batchable_dtype = _result_flat_signature_to_batchable_dtype(
+        result_flat_signature)
+    result_batchable_ta = [
+        tensor_array_ops.TensorArray(
+            dtype=dt, size=n, dynamic_size=False, infer_shape=infer_shape)
+        for dt in result_batchable_dtype
+    ]

     def compute(i, tas):
       """The loop body of map_fn.
@@ -252,30 +463,34 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
         (i + 1, tas): the updated counter + updated TensorArrays

       Raises:
-        TypeError: if dtype and packed_fn_values structure do not match
-        ValueType: if dtype and packed_fn_values lengths do not match
+        TypeError: if fn_output_signature and result_value structure don't match
+        ValueError: if fn_output_signature and result_value lengths don't match
       """
-      packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
-      packed_fn_values = fn(packed_values)
-      nest.assert_same_structure(dtype or elems, packed_fn_values)
-      flat_fn_values = output_flatten(packed_fn_values)
-      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
+      elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
+      elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
+                                                        elems_flat_signature)
+      elems_value = elems_unflatten(elems_value_flat)
+      result_value = fn(elems_value)
+      nest.assert_same_structure(fn_output_signature or elems, result_value)
+      result_value_flat = nest.flatten(result_value)
+      result_value_batchable = _result_value_flat_to_batchable(
+          result_value_flat, result_flat_signature)
+      tas = [
+          ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
+      ]
       return (i + 1, tas)

     _, r_a = control_flow_ops.while_loop(
-        lambda i, _: i < n, compute, (i, accs_ta),
+        lambda i, _: i < n,
+        compute, (i, result_batchable_ta),
         parallel_iterations=parallel_iterations,
         back_prop=back_prop,
         swap_memory=swap_memory,
         maximum_iterations=n)
-    results_flat = [r.stack() for r in r_a]
-
-    n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
-        elems_flat[0].get_shape().with_rank_at_least(1)[0]))
-    for elem in elems_flat[1:]:
-      n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value(
-          elem.get_shape().with_rank_at_least(1)[0])))
-    for r in results_flat:
+    result_batchable = [r.stack() for r in r_a]
+
+    # Update each output tensor w/ static shape info about the outer dimension.
+    for r in result_batchable:
       r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
           r.get_shape()[1:]))
@@ -284,7 +499,103 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
     if in_graph_mode and varscope_caching_device_was_none:
       varscope.set_caching_device(None)

-    return output_pack(results_flat)
+    result_flat = _result_batchable_to_flat(result_batchable,
+                                            result_flat_signature)
+    result = result_unflatten(result_flat)
+    return result
+
+
+def _dtype_to_spec(d):
+  if not isinstance(d, type_spec.TypeSpec):
+    d = tensor_spec.TensorSpec(None, d)
+  return d
+
+
+def _most_general_compatible_type(spec):
+  """Returns the most general TypeSpec compatible with `spec`."""
+  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
+  if isinstance(spec, tensor_spec.TensorSpec):
+    return tensor_spec.TensorSpec(None, spec.dtype)
+  elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
+    # pylint: disable=protected-access
+    return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
+                                          spec._row_splits_dtype)
+  elif isinstance(spec, sparse_tensor.SparseTensorSpec):
+    # pylint: disable=protected-access
+    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
+  else:
+    return spec
+
+
+def _result_flat_signature_to_batchable_dtype(result_flat_signature):
+  """Converts result_flat_signature -> result_batchable_dtype."""
+  components = []
+  for spec in result_flat_signature:
+    if not isinstance(spec, type_spec.BatchableTypeSpec):
+      raise TypeError("map_fn can not generate %s outputs" % (spec,))
+    # pylint: disable=protected-access
+    components.extend([s.dtype for s in spec._flat_tensor_specs])
+  return components
+
+
+def _elems_flat_to_batchable(elems_flat):
+  """Converts elems_flat -> elems_batchable."""
+  elems_batchable = []
+  for elems_tensor in elems_flat:
+    spec = type_spec.type_spec_from_value(elems_tensor)
+    if not isinstance(spec, type_spec.BatchableTypeSpec):
+      raise TypeError("map_fn can not consume %s inputs: got %r" %
+                      (spec, elems_tensor))
+    # pylint: disable=protected-access
+    elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
+  return elems_batchable
+
+
+def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
+  """Converts elems_value_batchable -> elems_value_flat."""
+  elems_value_flat = []
+  i = 0
+  for spec in elems_flat_signature:
+    # pylint: disable=protected-access
+    spec = spec._unbatch()
+    tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
+    elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
+    i += len(tensor_list)
+  assert i == len(elems_value_batchable)
+  return elems_value_flat
+
+
+def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
+  """Converts result_value_flat -> result_value_batchable."""
+  result_value_batchable = []
+  for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
+    if isinstance(r_spec, tensor_spec.TensorSpec):
+      result_value_batchable.append(r_value)
+    else:
+      if not r_spec.is_compatible_with(r_value):
+        raise ValueError(
+            "Error in map_fn:\n  Expected `fn` to return a:\n    %s\n"
+            "  But it returned a:\n    %s\n    (value=%s)\n"
+            "  To fix, update the `fn_output_signature` (or `dtype`) "
+            "argument to `map_fn`." %
+            (r_spec, type_spec.type_spec_from_value(r_value), r_value))
+      result_value_batchable.extend(r_spec._to_tensor_list(r_value))  # pylint: disable=protected-access
+  return result_value_batchable
+
+
+def _result_batchable_to_flat(result_batchable, result_flat_signature):
+  """Converts result_batchable -> result_flat."""
+  result_flat = []
+  i = 0
+  for spec in result_flat_signature:
+    # pylint: disable=protected-access
+    num_tensors = len(spec._flat_tensor_specs)
+    result_flat.append(
+        spec._batch(None)._from_compatible_tensor_list(
+            result_batchable[i:i + num_tensors]))
+    i += num_tensors
+  assert i == len(result_batchable)
+  return result_flat


 @tf_export("map_fn", v1=[])
@@ -297,6 +608,7 @@ Use:
 results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
     warn_once=True,
     back_prop=False)
+@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
 def map_fn_v2(fn,
               elems,
               dtype=None,
@@ -304,122 +616,25 @@ def map_fn_v2(fn,
               back_prop=True,
               swap_memory=False,
               infer_shape=True,
-              name=None):
-  """map on the list of tensors unpacked from `elems` on dimension 0.
-
-  The simplest version of `map_fn` repeatedly applies the callable `fn` to a
-  sequence of elements from first to last. The elements are made of the
-  tensors unpacked from `elems`. `dtype` is the data type of the return
-  value of `fn`. Users must provide `dtype` if it is different from
-  the data type of `elems`.
-
-  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
-  of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.
-
-  This method also allows multi-arity `elems` and output of `fn`.  If `elems`
-  is a (possibly nested) list or tuple of tensors, then each of these tensors
-  must have a matching first (unpack) dimension.  The signature of `fn` may
-  match the structure of `elems`.  That is, if `elems` is
-  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
-  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
-
-  Furthermore, `fn` may emit a different structure than its input.  For example,
-  `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`.  In this case,
-  the `dtype` parameter is not optional: `dtype` must be a type or (possibly
-  nested) tuple of types matching the output of `fn`.
-
-  To apply a functional operation to the nonzero elements of a SparseTensor
-  one of the following methods is recommended. First, if the function is
-  expressible as TensorFlow ops, use
-
-  ```python
-    result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
-  ```
-
-  If, however, the function is not expressible as a TensorFlow op, then use
-
-  ```python
-  result = SparseTensor(
-    input.indices, map_fn(fn, input.values), input.dense_shape)
-  ```
-
-  instead.
-
-  When executing eagerly, map_fn does not execute in parallel even if
-  `parallel_iterations` is set to a value > 1. You can still get the
-  performance benefits of running a function in parallel by using the
-  `tf.function` decorator,
-
-  ```python
-  # Assume the function being used in map_fn is fn.
-  # To ensure map_fn calls fn in parallel, use the tf.function decorator.
-  @tf.function
-  def func(tensor):
-    return tf.map_fn(fn, tensor)
-  ```
-
-  Note that if you use the `tf.function` decorator, any non-TensorFlow Python
-  code that you may have written in your function won't get executed. See
-  [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for
-  more details. The recommendation would be to debug without `tf.function` but
-  switch to it to get performance benefits of running `map_fn` in parallel.
-
-  Args:
-    fn: The callable to be performed. It accepts one argument, which will have
-      the same (possibly nested) structure as `elems`. Its output must have the
-      same structure as `dtype` if one is provided, otherwise it must have the
-      same structure as `elems`.
-    elems: A tensor or (possibly nested) sequence of tensors, each of which will
-      be unpacked along their first dimension. The nested sequence of the
-      resulting slices will be applied to `fn`.
-    dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure
-      of Tensors differing from the structure of `elems`, then `dtype` is not
-      optional and must have the same structure as the output of `fn`.
-    parallel_iterations: (optional) The number of iterations allowed to run in
-      parallel. When graph building, the default value is 10. While executing
-      eagerly, the default value is set to 1.
-    back_prop: (optional) Deprecated. False disables support for back
-      propagation. Prefer using `tf.stop_gradient` instead.
-    swap_memory: (optional) True enables GPU-CPU memory swapping.
-    infer_shape: (optional) False disables tests for consistent output shapes.
-    name: (optional) Name prefix for the returned tensors.
-
-  Returns:
-    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
-    results of applying `fn` to tensors unpacked from `elems` along the first
-    dimension, from first to last.
-
-  Raises:
-    TypeError: if `fn` is not callable or the structure of the output of
-      `fn` and `dtype` do not match, or if elems is a SparseTensor.
-    ValueError: if the lengths of the output of `fn` and `dtype` do not match.
-
-  Examples:
-    ```python
-    elems = np.array([1, 2, 3, 4, 5, 6])
-    squares = map_fn(lambda x: x * x, elems)
-    # squares == [1, 4, 9, 16, 25, 36]
-    ```
-
-    ```python
-    elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
-    alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
-    # alternate == [-1, 2, -3]
-    ```
-
-    ```python
-    elems = np.array([1, 2, 3])
-    alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
-    # alternates[0] == [1, 2, 3]
-    # alternates[1] == [-1, -2, -3]
-    ```
-  """
+              name=None,
+              fn_output_signature=None):
+  """Transform `elems` by applying `fn` to each element unstacked on axis 0."""
+  if fn_output_signature is None:
+    fn_output_signature = dtype
   return map_fn(
       fn=fn,
       elems=elems,
-      dtype=dtype,
+      fn_output_signature=fn_output_signature,
       parallel_iterations=parallel_iterations,
       back_prop=back_prop,
       swap_memory=swap_memory,
       infer_shape=infer_shape,
       name=name)
+
+
+# Docstring for v2 is the same as v1, except that back_prop is deprecated.
+map_fn_v2.__doc__ = re.sub(
+    r"(  back_prop: \(optional\) )(.*)",
+    r"\1Deprecated: prefer using `tf.stop_gradient` instead.  \2",
+    map_fn.__doc__)
+assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__

View File

@@ -46,13 +46,16 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
       dict(
           fn=mo.reduce_mean,
           elems=[[1, 2, 3], [4, 5], [6, 7]],
+          elems_dtype=dtypes.int32,
           expected_output=[2, 4, 6],
+          result_dtype=dtypes.int32,
       ),
       dict(
           fn=string_ops.reduce_join,
           elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']],
           expected_output=[b'foobarbaz', b'a', b'bc'],
-          dtype=dtypes.string,
+          elems_dtype=dtypes.string,
+          result_dtype=dtypes.string,
       ),
       # [d1, (d2)] -> [d1, 2]
       dict(
@@ -60,7 +63,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
           # fn=self.stack_mean_and_sum,
           elems=[[1, 2, 3], [4, 5], [6, 7]],
           expected_output=[[2, 6], [4.5, 9], [6.5, 13]],
-          dtype=dtypes.float32,
+          elems_dtype=dtypes.float32,
+          result_dtype=dtypes.float32,
           expected_ragged_rank=0,
       ),
       # [d1, (d2)] -> [d1, (d2)]
@@ -68,7 +72,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
           fn=lambda x: x + np.int64(1),
           elems=[[1, 2, 3], [4, 5], [6, 7]],
           expected_output=[[2, 3, 4], [5, 6], [7, 8]],
-          dtype=dtypes.int64,
+          elems_dtype=dtypes.int64,
           result_dtype=ragged_tensor.RaggedTensorType(
               dtype=dtypes.int64, ragged_rank=1),
       ),
@@ -157,11 +161,11 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
       expected_ragged_rank=None,
       result_ragged_rank=None,
       elems_ragged_rank=None,
-      dtype=dtypes.int64,
+      elems_dtype=dtypes.int64,
       result_dtype=None,
-      infer_shape=False,
+      infer_shape=True,
   ):
-    elems = ragged_factory_ops.constant(elems, dtype, elems_ragged_rank)
+    elems = ragged_factory_ops.constant(elems, elems_dtype, elems_ragged_rank)
     output = ragged_map_ops.map_fn(
         fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape)
@@ -260,8 +264,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
   def testMismatchRaggedRank(self):
     elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
     fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
-    with self.assertRaisesWithLiteralMatch(
-        ValueError, r'The declared ragged rank (23) mismatches the result (1)'):
+    with self.assertRaisesRegexp(
+        ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
       _ = ragged_map_ops.map_fn(
           fn,
           elems,
@@ -271,8 +275,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
   def testMismatchRaggedRank2(self):
     elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
     fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
-    with self.assertRaisesWithLiteralMatch(
-        ValueError, r'The declared ragged rank (10) mismatches the result (2)'):
+    with self.assertRaisesRegexp(
+        ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
       _ = ragged_map_ops.map_fn(
           fn,
           elems,
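
The rewritten assertions track the new error text raised by `_result_value_flat_to_batchable` when `fn`'s actual output is incompatible with the declared signature. A hedged sketch of the failure mode these tests now expect (assuming this change is present; values are illustrative):

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [4, 5], [6, 7]])
try:
  tf.map_fn(
      fn=lambda x: tf.RaggedTensor.from_row_starts(x, [0]),  # ragged_rank=1
      elems=rt,
      # Deliberately mismatched declared ragged_rank:
      fn_output_signature=tf.RaggedTensorSpec(shape=None, dtype=tf.int32,
                                              ragged_rank=10))
except ValueError as e:
  print(e)  # Error in map_fn: Expected `fn` to return a: ... But it returned a: ...
```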

View File

@@ -17,22 +17,14 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import collections
-
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import tensor_array_ops
-from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops.ragged import ragged_config
 from tensorflow.python.ops.ragged import ragged_tensor
-from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+map_fn_lib = LazyLoader(
+    "map_fn_lib", globals(),
+    "tensorflow.python.ops.map_fn")


 def map_fn(fn,
@@ -166,298 +158,25 @@ def map_fn(fn,
     #   out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
   ```
   """
-  if not callable(fn):
-    raise TypeError("fn must be callable.")
-
-  if isinstance(elems, sparse_tensor.SparseTensor):
-    raise TypeError(
-        "To perform a map on the values of a sparse tensor use either "
-        " SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
-        " SparseTensor(input.indices, map_fn(fn, input.values), "
-        "input.dense_shape)")
-
-  in_graph_mode = not context.executing_eagerly()
-  # Set the default number of parallel_iterations depending on graph/eager mode.
-  if in_graph_mode and not parallel_iterations:
-    parallel_iterations = 10
-  elif not in_graph_mode and not parallel_iterations:
-    parallel_iterations = 1
-
-  if not in_graph_mode and parallel_iterations > 1:
-    logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
-                        "effect when executing eagerly. Consider calling map_fn"
-                        " with tf.contrib.eager.defun to execute fn in "
-                        "parallel.", 1)
-    parallel_iterations = 1
-
-  input_is_sequence = nest.is_sequence(elems)
-  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
-
-  def input_pack(x):
-    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
-
-  elems_flat = input_flatten(elems)
-  elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)
-
-  with ops.name_scope(name, "map", elems_flat):
-    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
-    # supported in Eager
-    if in_graph_mode:
-      # Any get_variable calls in fn will cache the first call locally
-      # and not issue repeated network I/O requests for each iteration.
-      varscope = vs.get_variable_scope()
-      varscope_caching_device_was_none = False
-      if varscope.caching_device is None:
-        # TODO(ebrevdo): Change to using colocate_with here and in other
-        # methods.
-        varscope.set_caching_device(lambda op: op.device)
-        varscope_caching_device_was_none = True
-
-    elems_flat = [
-        ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem")
-        for elem in elems_flat
-    ]
-
-    # We can either infer the output, or we can assume that it will be the same
-    # as the input structure.
-    dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
-
-    # Find the number of iterations, n may be known statically.
-    if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
-      n = elems_flat[0].nrows(out_type=dtypes.int32)
-    else:
-      static_shape = elems_flat[0].shape
-      if static_shape.ndims is not None and static_shape.ndims < 1:
-        if len(elems_flat) == 1:
-          raise ValueError(
-              "elems must be a 1+ dimensional Tensor, not a scalar")
-        else:
-          raise ValueError(
-              "elements in elems must be 1+ dimensional Tensors, not scalars")
-      n = (tensor_shape.dimension_value(static_shape[0]) or
-           array_ops.shape(elems_flat[0])[0])
-
-    n = math_ops.cast(n, dtype=dtypes.int32)
-    # Create a flat list of TAs.
-    # Flatten the dtype structure to a list.
-    dtype_flat = nest.flatten(dtype)
-
-    # decompose to components
-    dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
-    dtype_components_flat = nest.flatten(dtype_components)
-
-    # Create TensorArrays.
-    accs_ta = [
-        tensor_array_ops.TensorArray(
-            dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
-        for t in dtype_components_flat
-    ]
-
-    i = constant_op.constant(0, dtype=dtypes.int32)
-
-    def compute(i, tas):
-      """The loop body of map_fn.
-
-      Args:
-        i: the loop counter
-        tas: the flat TensorArray accumulator list
-
-      Returns:
-        (i + 1, tas): the updated counter + updated TensorArrays
-
-      Raises:
-        TypeError: if dtype and packed_fn_values structure do not match
-        ValueType: if dtype and packed_fn_values lengths do not match
-      """
-      # Get Tensors or RaggedTensors sliced at i, then pack it back to the
-      # original structure.
-      packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
-      packed_fn_values = fn(packed_values)
-
-      # Check that the structure of the output matches what was declared or
-      # inferred.
-      # nest.assert_same_structure(dtype or elems, packed_fn_values)
-
-      # Flatten and decompose to a list of Tensors
-      flat_fn_values = nest.flatten(packed_fn_values)
-
-      # If we declared that we are expecting a RaggedTensor output, but we get a
-      # Tensor output. We should try to convert it to a RaggedTensor.
-      flat_fn_composite_tensors = list(
-          _convert_declared(flat_fn_values, dtype_flat))
-
-      flat_fn_components = [
-          _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
-      ]
-      flat_fn_tensors = nest.flatten(flat_fn_components)
-
-      # Write to TAs.
-      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]
-
-      return (i + 1, tas)
-
-    _, r_a = control_flow_ops.while_loop(
-        lambda i, _: i < n, compute, (i, accs_ta),
-        parallel_iterations=parallel_iterations,
-        back_prop=back_prop,
-        swap_memory=swap_memory,
-        maximum_iterations=n)
-
-    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
-    # supported in Eager
-    if in_graph_mode and varscope_caching_device_was_none:
-      varscope.set_caching_device(None)
-
-    # Pack back into a list of components
-    results_as_components = nest.pack_sequence_as(dtype_components, r_a)
-
-    # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
-    def _stack_or_concat(e):
-      if isinstance(e, _RaggedTensorComponents):
-        return _concat_ragged_tensor_components(e)
-      else:
-        result = e.stack()
-        return result
-
-    results_flat_components = [
-        _stack_or_concat(e) for e in results_as_components
-    ]
-
-    results_packed = [
-        _maybe_recompose_tensor(c) for c in results_flat_components
-    ]
-    results_packed = nest.pack_sequence_as(dtype, results_packed)
-    return results_packed
+  if dtype is None:
+    dtype = nest.map_structure(lambda e: e.dtype, elems)
+  dtype = nest.map_structure(_ragged_type_to_spec, dtype)
+  return map_fn_lib.map_fn(fn,
+                           elems,
+                           dtype,
+                           parallel_iterations,
+                           back_prop,
+                           swap_memory,
+                           infer_shape,
+                           name)


-class _RaggedTensorComponents(
-    collections.namedtuple(
-        "_RaggedTensorComponents",
-        ["flat_values", "nested_row_lengths", "outer_row_length"])):
-  """A namedtuple of components which represent a `RaggedTensor`.
-
-  _RaggedTensorComponents is a list of components which can be used to create a
-  `RaggedTensor`. Use this class to represent a `RaggedTensor` in situations
-  where nest.flatten and nest.pack_sequence_as should decompose ragged tensors
-  into their components..
-
-  The following are a list of components for a `RaggedTensor`:
-
-  flat_values: The flat and inner values of a RaggedTensor. This could be
-    a `Tensor`, a `TensorArray`, or a data type.
-  nested_row_lengths: a tuple containing the row lengths of each rank. The
-    elements of the tuple could be `Tensor`s or `TensorArray`s.
-  outer_row_length: a `Tensor` or `TensorArray` containing the row length of the
-    `RaggedTensor`'s outermost dimension.
-
-  See `RaggedTensor` for more details of the use of each component.
-  """
-  __slots__ = ()
-
-
-def _concat_ragged_tensor_components(rt_ta):
-  flat_values = rt_ta.flat_values.concat()
-  nested_row_lengths = tuple(
-      row_lengths_ta.concat() for row_lengths_ta in rt_ta.nested_row_lengths)
-  outer_row_length = rt_ta.outer_row_length.concat()
-  return _RaggedTensorComponents(
-      flat_values=flat_values,
-      nested_row_lengths=nested_row_lengths,
-      outer_row_length=outer_row_length)
-
-
-def _maybe_decompose_tensor(rt):
-  """Decompose tensors to their composite tensors."""
-  if not isinstance(rt, ragged_tensor.RaggedTensor):
-    return rt
-
-  # The three component pieces we need:
-  # - inner values
-  flat_values = rt.flat_values
-
-  # - row_splits of the RT
-  splits = rt.nested_row_splits
-  nested_row_lengths = tuple(split[1:] - split[:-1] for split in splits)
-
-  # - outer row length
-  outer_row_length = array_ops.expand_dims(rt.nrows(), axis=0)
-
-  return _RaggedTensorComponents(
-      flat_values=flat_values,
-      nested_row_lengths=nested_row_lengths,
-      outer_row_length=outer_row_length,
-  )
-
-
-def _maybe_recompose_tensor(t):
-  """Reconstructs a _RaggedTensorComponents into a RaggedTensor."""
-  if not isinstance(t, _RaggedTensorComponents):
-    return t
-
-  values = t.flat_values
-  nested_row_lengths = tuple(t.nested_row_lengths)
-  for nested_row_length in reversed(nested_row_lengths):
-    values = ragged_tensor.RaggedTensor.from_row_lengths(
-        values, nested_row_length, validate=False)
-  return ragged_tensor.RaggedTensor.from_row_lengths(values, t.outer_row_length,
-                                                     validate=False)
-
-
-def _maybe_decompose_dtype(d):
-  """Decompose dtypes into composite tensors (if necessary)."""
-  if not isinstance(d, ragged_tensor.RaggedTensorType):
-    return d
-
-  result = _RaggedTensorComponents(
-      flat_values=d.dtype,
-      nested_row_lengths=tuple(
-          d.row_splits_dtype for i in range(d.ragged_rank - 1)),
-      outer_row_length=d.row_splits_dtype,
-  )
-  return result
-
-
-def _convert_declared(fn_output_flat, output_declared):
-  """Convert outputs which are `Tensor`s into `_RaggedTensorComponents`."""
-  for current, declared in zip(fn_output_flat, output_declared):
-    if isinstance(declared, ragged_tensor.RaggedTensorType):
-      yield _convert_declared_ragged(current, declared)
-    else:
-      yield current
-
-
-def _convert_declared_ragged(current, declared):
-  """Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
-  # Check that the ragged ranks match up.
-  # + 1 to account for the rank of the outermost dimension.
-  current_ragged_rank = getattr(current, "ragged_rank", 0)
-  if declared.ragged_rank != current_ragged_rank + 1:
-    raise ValueError(
-        "The declared ragged rank (%d) mismatches the result (%d)" %
-        (declared.ragged_rank, current_ragged_rank + 1))
-
-  # Check that dtypes match up.
-  if declared.dtype != current.dtype:
-    raise ValueError(
-        "The declared dtype (%s) mismatches the result (%s)" %
-        (declared.dtype, current.dtype))
-  if (isinstance(current, ragged_tensor.RaggedTensor) and
-      declared.row_splits_dtype != current.row_splits.dtype):
-    if not ragged_config.auto_cast_partition_dtype():
-      raise ValueError(
-          "The declared row_splits dtype (%s) mismatches the result (%s)."
-          " Use RaggedTensor.with_row_splits_dtype to convert it."
-          % (declared.row_splits_dtype, current.row_splits.dtype))
-    current = current.with_row_splits_dtype(declared.row_splits_dtype)
-
-  if isinstance(current, ragged_tensor.RaggedTensor):
-    return current
-  else:
-    nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
-    row_length = array_ops.expand_dims(nrows, axis=0)
-    return _RaggedTensorComponents(
-        flat_values=current,
-        nested_row_lengths=(),
-        outer_row_length=row_length)
+def _ragged_type_to_spec(t):
+  if isinstance(t, ragged_tensor.RaggedTensorType):
+    # Note: need to adjust ragged_rank by 1, since RaggedTensorSpec gives the
+    # type for the mapped `fn` output, but RaggedTensorType gives the type for
+    # the result of stacking the mapped `fn` outputs.
+    return ragged_tensor.RaggedTensorSpec(
+        None, t.dtype, t.ragged_rank - 1, t.row_splits_dtype)
+  else:
+    return t
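
`ragged_map_ops.map_fn` is now a thin wrapper: a legacy `RaggedTensorType` dtype is translated into a `RaggedTensorSpec` with `ragged_rank` reduced by one, since the spec describes one mapped output while the type described the stacked result. A sketch of the equivalent call against the public API (assuming this change; it mirrors the `out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])` example above):

```python
import tensorflow as tf

elems = tf.ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64)
# RaggedTensorType(dtype=tf.int64, ragged_rank=1) for the stacked result
# becomes a per-element RaggedTensorSpec with ragged_rank=0 (shape=[None]):
out = tf.map_fn(
    lambda x: x + 1, elems,
    fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int64))
print(out)  # <tf.RaggedTensor [[2, 3, 4], [5, 6], [7, 8]]>
```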

View File

@@ -2330,6 +2330,10 @@ class RaggedTensorType(object):
   ragged_rank = property(lambda self: self._ragged_rank)
   row_splits_dtype = property(lambda self: self._row_splits_dtype)

+  def __repr__(self):
+    return "RaggedTensorType(%r, %r, %r)" % (
+        self.dtype, self.ragged_rank, self.row_splits_dtype)
+

 #===============================================================================
 # Helper Functions

View File

@@ -1626,7 +1626,7 @@ tf_module {
   }
   member_method {
     name: "map_fn"
-    argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
+    argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "matching_files"

View File

@@ -786,7 +786,7 @@ tf_module {
   }
   member_method {
     name: "map_fn"
-    argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
+    argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "matmul"