Update tf.map_fn to support RaggedTensors and SparseTensors.
PiperOrigin-RevId: 300134914 Change-Id: I84e75ea1d71257d34d045b82b54de208ea2fe75b
This commit is contained in:
parent
3299b778c3
commit
0bbfe8d95d
tensorflow
python
kernel_tests
ops
tools/api/golden
@ -69,13 +69,14 @@ class MapFnTest(test.TestCase):
|
||||
|
||||
def testMapSparseTensor(self):
|
||||
with self.cached_session():
|
||||
with self.assertRaises(TypeError):
|
||||
map_fn.map_fn(
|
||||
lambda x: x,
|
||||
sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [0, 1], [1, 0]],
|
||||
values=constant_op.constant([0, 1, 2]),
|
||||
dense_shape=[2, 2]))
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [0, 1], [1, 0]],
|
||||
values=constant_op.constant([0, 1, 2]),
|
||||
dense_shape=[2, 2])
|
||||
result = map_fn.map_fn(lambda x: x, st)
|
||||
self.assertAllEqual(result.indices, st.indices)
|
||||
self.assertAllEqual(result.values, st.values)
|
||||
self.assertAllEqual(result.dense_shape, st.dense_shape)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testMapOverScalarErrors(self):
|
||||
|
@ -20,15 +20,20 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import re
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
@ -36,134 +41,324 @@ from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@tf_export(v1=["map_fn"])
|
||||
def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
|
||||
swap_memory=False, infer_shape=True, name=None):
|
||||
"""map on the list of tensors unpacked from `elems` on dimension 0.
|
||||
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
|
||||
def map_fn(fn,
|
||||
elems,
|
||||
dtype=None,
|
||||
parallel_iterations=None,
|
||||
back_prop=True,
|
||||
swap_memory=False,
|
||||
infer_shape=True,
|
||||
name=None,
|
||||
fn_output_signature=None):
|
||||
"""Transforms `elems` by applying `fn` to each element unstacked on axis 0.
|
||||
|
||||
The simplest version of `map_fn` repeatedly applies the callable `fn` to a
|
||||
sequence of elements from first to last. The elements are made of the
|
||||
tensors unpacked from `elems`. `dtype` is the data type of the return
|
||||
value of `fn`. Users must provide `dtype` if it is different from
|
||||
the data type of `elems`.
|
||||
`map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements;
|
||||
calls `fn` to transform each element; and then stacks the transformed
|
||||
values back together.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.
|
||||
#### Mapping functions with single-Tensor inputs and outputs
|
||||
|
||||
This method also allows multi-arity `elems` and output of `fn`. If `elems`
|
||||
is a (possibly nested) list or tuple of tensors, then each of these tensors
|
||||
must have a matching first (unpack) dimension. The signature of `fn` may
|
||||
match the structure of `elems`. That is, if `elems` is
|
||||
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
|
||||
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
|
||||
If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`,
|
||||
then `map_fn(fn, elems)` is equivalent to
|
||||
`tf.stack([fn(elem) for elem in tf.unstack(elems)])`. E.g.:
|
||||
|
||||
Furthermore, `fn` may emit a different structure than its input. For example,
|
||||
`fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case,
|
||||
the `dtype` parameter is not optional: `dtype` must be a type or (possibly
|
||||
nested) tuple of types matching the output of `fn`.
|
||||
>>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
|
||||
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
|
||||
array([[3, 4, 5],
|
||||
[5, 6, 7],
|
||||
[2, 3, 4]], dtype=int32)>
|
||||
|
||||
To apply a functional operation to the nonzero elements of a SparseTensor
|
||||
one of the following methods is recommended. First, if the function is
|
||||
expressible as TensorFlow ops, use
|
||||
`map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`.
|
||||
|
||||
```python
|
||||
result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
|
||||
```
|
||||
#### Mapping functions with multi-arity inputs and outputs
|
||||
|
||||
If, however, the function is not expressible as a TensorFlow op, then use
|
||||
`map_fn` also supports functions with multi-arity inputs and outputs:
|
||||
|
||||
```python
|
||||
result = SparseTensor(
|
||||
input.indices, map_fn(fn, input.values), input.dense_shape)
|
||||
```
|
||||
* If `elems` is a tuple (or nested structure) of tensors, then those tensors
|
||||
must all have the same outer-dimension size (`num_elems`); and `fn` is
|
||||
used to transform each tuple (or structure) of corresponding slices from
|
||||
`elems`. E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to
|
||||
transform each tuple of slices `(t1[i], t2[i], t3[i])`
|
||||
(where `0 <= i < num_elems`).
|
||||
|
||||
instead.
|
||||
* If `fn` returns a tuple (or nested structure) of tensors, then the
|
||||
result is formed by stacking corresponding elements from those structures.
|
||||
|
||||
When executing eagerly, map_fn does not execute in parallel even if
|
||||
#### Specifying `fn`'s output signature
|
||||
|
||||
If `fn`'s input and output signatures are different, then the output
|
||||
signature must be specified using `fn_output_signature`. (The input and
|
||||
output signatures are differ if their structures, dtypes, or tensor types do
|
||||
not match). E.g.:
|
||||
|
||||
>>> tf.map_fn(fn=tf.strings.length, # input & output have different dtypes
|
||||
... elems=tf.constant(["hello", "moon"]),
|
||||
... fn_output_signature=tf.int32)
|
||||
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
|
||||
>>> tf.map_fn(fn=tf.strings.join, # input & output have different structures
|
||||
... elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
|
||||
... fn_output_signature=tf.string)
|
||||
<tf.Tensor: shape=(2,), dtype=string,
|
||||
numpy=array([b'TheDog', b'ACat'], dtype=object)>
|
||||
|
||||
`fn_output_signature` can be specified using any of the following:
|
||||
|
||||
* A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
|
||||
* A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
|
||||
* A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
|
||||
* A (possibly nested) tuple, list, or dict containing the above types.
|
||||
|
||||
#### RaggedTensors
|
||||
|
||||
`map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular:
|
||||
|
||||
* If `elems` is a `RaggedTensor`, then `fn` will be called with each
|
||||
row of that ragged tensor.
|
||||
|
||||
* If `elems` has only one ragged dimension, then the values passed to
|
||||
`fn` will be `tf.Tensor`s.
|
||||
* If `elems` has multiple ragged dimensions, then the values passed to
|
||||
`fn` will be `tf.RaggedTensor`s with one fewer ragged dimension.
|
||||
|
||||
* If the result of `map_fn` should be a `RaggedTensor`, then use a
|
||||
`tf.RaggedTensorSpec` to specify `fn_output_signature`.
|
||||
|
||||
* If `fn` returns `tf.Tensor`s with varying sizes, then use a
|
||||
`tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a
|
||||
single ragged tensor (which will have ragged_rank=1).
|
||||
* If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec`
|
||||
with the same `ragged_rank`.
|
||||
|
||||
>>> # Example: RaggedTensor input
|
||||
>>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
|
||||
>>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
|
||||
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
|
||||
|
||||
>>> # Example: RaggedTensor output
|
||||
>>> elems = tf.constant([3, 5, 0, 2])
|
||||
>>> tf.map_fn(tf.range, elems,
|
||||
... fn_output_signature=tf.RaggedTensorSpec(shape=[None],
|
||||
... dtype=tf.int32))
|
||||
<tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
|
||||
|
||||
Note: `map_fn` should only be used if you need to map a function over the
|
||||
*rows* of a `RaggedTensor`. If you wish to map a function over the
|
||||
individual values, then you should use:
|
||||
|
||||
* `tf.ragged.map_flat_values(fn, rt)`
|
||||
(if fn is expressible as TensorFlow ops)
|
||||
* `rt.with_flat_values(map_fn(fn, rt.flat_values))`
|
||||
(otherwise)
|
||||
|
||||
E.g.:
|
||||
|
||||
>>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
|
||||
>>> tf.ragged.map_flat_values(lambda x: x + 2, rt)
|
||||
<tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
|
||||
|
||||
#### SparseTensors
|
||||
|
||||
`map_fn` supports `tf.SparseTensor` inputs and outputs. In particular:
|
||||
|
||||
* If `elems` is a `SparseTensor`, then `fn` will be called with each row
|
||||
of that sparse tensor. In particular, the value passed to `fn` will be a
|
||||
`tf.SparseTensor` with one fewer dimension than `elems`.
|
||||
|
||||
* If the result of `map_fn` should be a `SparseTensor`, then use a
|
||||
`tf.SparseTensorSpec` to specify `fn_output_signature`. The individual
|
||||
`SparseTensor`s returned by `fn` will be stacked into a single
|
||||
`SparseTensor` with one more dimension.
|
||||
|
||||
>>> # Example: SparseTensor input
|
||||
>>> st = tf.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
|
||||
>>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
|
||||
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
|
||||
|
||||
>>> # Example: SparseTensor output
|
||||
>>> tf.sparse.to_dense(
|
||||
... tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
|
||||
... fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
|
||||
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
|
||||
array([[[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 0.]],
|
||||
[[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]]], dtype=float32)>
|
||||
|
||||
Note: `map_fn` should only be used if you need to map a function over the
|
||||
*rows* of a `SparseTensor`. If you wish to map a function over the nonzero
|
||||
values, then you should use:
|
||||
|
||||
* `tf.SparseTensor(st.indices, fn(st.values), st.dense_shape)`
|
||||
(if the function is expressible as TensorFlow ops)
|
||||
* `tf.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)`
|
||||
(otherwise).
|
||||
|
||||
#### `map_fn` vs. vectorized operations
|
||||
|
||||
`map_fn` will apply the operations used by `fn` to each element of `elems`,
|
||||
resulting in `O(elems.shape[0])` total operations. This is somewhat
|
||||
mitigated by the fact that `map_fn` can process elements in parallel.
|
||||
However, a transform expressed using `map_fn` is still typically less
|
||||
efficient than an equivalent transform expressed using vectorized operations.
|
||||
|
||||
`map_fn` should typically only be used if one of the following is true:
|
||||
|
||||
* It is difficult or expensive to express the desired transform with
|
||||
vectorized operations.
|
||||
* `fn` creates large intermediate values, so an equivalent vectorized
|
||||
transform would take too much memory.
|
||||
* Processing elements in parallel is more efficient than an equivalent
|
||||
vectorized transform.
|
||||
* Efficiency of the transform is not critical, and using `map_fn` is
|
||||
more readable.
|
||||
|
||||
E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)`
|
||||
across `elems` could be rewritten more efficiently using vectorized ops:
|
||||
|
||||
>>> elems = tf.constant([3, 5, 2])
|
||||
>>> tf.range(3) + tf.expand_dims(elems, 1)
|
||||
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
|
||||
array([[3, 4, 5],
|
||||
[5, 6, 7],
|
||||
[2, 3, 4]], dtype=int32)>
|
||||
|
||||
In some cases, `tf.vectorized_map` can be used to automatically convert a
|
||||
function to a vectorized eqivalent.
|
||||
|
||||
#### Eager execution
|
||||
|
||||
When executing eagerly, `map_fn` does not execute in parallel even if
|
||||
`parallel_iterations` is set to a value > 1. You can still get the
|
||||
performance benefits of running a function in parallel by using the
|
||||
`tf.function` decorator,
|
||||
`tf.function` decorator:
|
||||
|
||||
>>> fn=lambda t: tf.range(t, t + 3)
|
||||
>>> @tf.function
|
||||
... def func(elems):
|
||||
... return tf.map_fn(fn, elems, parallel_iterations=3)
|
||||
>>> func(tf.constant([3, 5, 2]))
|
||||
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
|
||||
array([[3, 4, 5],
|
||||
[5, 6, 7],
|
||||
[2, 3, 4]], dtype=int32)>
|
||||
|
||||
```python
|
||||
# Assume the function being used in map_fn is fn.
|
||||
# To ensure map_fn calls fn in parallel, use the tf.function decorator.
|
||||
@tf.function
|
||||
def func(tensor):
|
||||
return tf.map_fn(fn, tensor)
|
||||
```
|
||||
|
||||
Note that if you use the `tf.function` decorator, any non-TensorFlow Python
|
||||
code that you may have written in your function won't get executed. See
|
||||
[`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for
|
||||
more details. The recommendation would be to debug without `tf.function` but
|
||||
switch to it to get performance benefits of running `map_fn` in parallel.
|
||||
`tf.function` for more details. The recommendation would be to debug without
|
||||
`tf.function` but switch to it to get performance benefits of running `map_fn`
|
||||
in parallel.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed. It accepts one argument, which will
|
||||
have the same (possibly nested) structure as `elems`. Its output
|
||||
must have the same structure as `dtype` if one is provided, otherwise
|
||||
it must have the same structure as `elems`.
|
||||
elems: A tensor or (possibly nested) sequence of tensors, each of which
|
||||
will be unpacked along their first dimension. The nested sequence
|
||||
of the resulting slices will be applied to `fn`.
|
||||
dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure
|
||||
of Tensors differing from the structure of `elems`, then `dtype` is not
|
||||
optional and must have the same structure as the output of `fn`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel. When graph building, the default value is 10. While executing
|
||||
fn: The callable to be performed. It accepts one argument, which will have
|
||||
the same (possibly nested) structure as `elems`. Its output must have the
|
||||
same structure as `fn_output_signature` if one is provided; otherwise it
|
||||
must have the same structure as `elems`.
|
||||
elems: A tensor or (possibly nested) sequence of tensors, each of which will
|
||||
be unstacked along their first dimension. `fn` will be applied to the
|
||||
nested sequence of the resulting slices. `elems` may include ragged and
|
||||
sparse tensors.
|
||||
dtype: Deprecated: Equivalent to `fn_output_signature`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run in
|
||||
parallel. When graph building, the default value is 10. While executing
|
||||
eagerly, the default value is set to 1.
|
||||
back_prop: (optional) True enables support for back propagation.
|
||||
back_prop: (optional) False disables support for back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
infer_shape: (optional) False disables tests for consistent output shapes.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
fn_output_signature: The output signature of `fn`. Must be specified if
|
||||
`fn`'s input and output signatures are different (i.e., if their
|
||||
structures, dtypes, or tensor types do not match).
|
||||
`fn_output_signature` can be specified using any of the following:
|
||||
|
||||
* A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`)
|
||||
* A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`)
|
||||
* A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`)
|
||||
* A (possibly nested) tuple, list, or dict containing the above types.
|
||||
|
||||
Returns:
|
||||
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
|
||||
results of applying `fn` to tensors unpacked from `elems` along the first
|
||||
dimension, from first to last.
|
||||
A tensor or (possibly nested) sequence of tensors. Each tensor stacks the
|
||||
results of applying `fn` to tensors unstacked from `elems` along the first
|
||||
dimension, from first to last. The result may include ragged and sparse
|
||||
tensors.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable or the structure of the output of
|
||||
`fn` and `dtype` do not match, or if elems is a SparseTensor.
|
||||
ValueError: if the lengths of the output of `fn` and `dtype` do not match.
|
||||
`fn` and `fn_output_signature` do not match.
|
||||
ValueError: if the lengths of the output of `fn` and `fn_output_signature`
|
||||
do not match.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
elems = np.array([1, 2, 3, 4, 5, 6])
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
|
||||
```python
|
||||
elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
|
||||
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
|
||||
# alternate == [-1, 2, -3]
|
||||
```
|
||||
>>> elems = np.array([1, 2, 3, 4, 5, 6])
|
||||
>>> tf.map_fn(lambda x: x * x, elems)
|
||||
<tf.Tensor: shape=(6,), dtype=int64, numpy=array([ 1, 4, 9, 16, 25, 36])>
|
||||
|
||||
```python
|
||||
elems = np.array([1, 2, 3])
|
||||
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
|
||||
# alternates[0] == [1, 2, 3]
|
||||
# alternates[1] == [-1, -2, -3]
|
||||
```
|
||||
>>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
|
||||
>>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64)
|
||||
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, 2, -3])>
|
||||
|
||||
>>> elems = np.array([1, 2, 3])
|
||||
>>> tf.map_fn(lambda x: (x, -x), elems,
|
||||
... fn_output_signature=(tf.int64, tf.int64))
|
||||
(<tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>,
|
||||
<tf.Tensor: shape=(3,), dtype=int64, numpy=array([-1, -2, -3])>)
|
||||
"""
|
||||
# This function uses a `while_loop` to call `fn` on each value of the input
|
||||
# tensor(s) (unstacked on dimension 0). The following sequence of variables
|
||||
# are used to transform the input tensor(s) (`elems`) into the output
|
||||
# tensor(s) (`result`):
|
||||
#
|
||||
# - Preparing and unstacking input values for the while_loop:
|
||||
# - elems: The input tensor(s) to map_fn. May include composite tensors.
|
||||
# - elems_flat: Flattened list of tensors from elems (using nest.flatten)
|
||||
# May include composite tensors.
|
||||
# - elems_batchable: Concatenation of "batchable tensor lists" for each
|
||||
# tensor in elems_flat. This "boxes" composite tensors
|
||||
# into sliceable tf.Tensor objects. For more info see:
|
||||
# TensorSpec._to_batched_tensor_list
|
||||
# - elems_batchable_ta: List of TensorArrays used to unstack each Tensor
|
||||
# in elems_batchable into elems_value_batchable.
|
||||
#
|
||||
# - Calling `fn` on each unstacked value in the body of the while_loop:
|
||||
# - elems_value_batchable: Single unstacked value from elems_batchable.
|
||||
# - elems_value_flat: Single unstacked value from elems_flat,
|
||||
# constructed from elems_value_batchable (using
|
||||
# TensorSpec._from_tensor_list).
|
||||
# - elems_value: Single unstacked value from elems (the input to fn).
|
||||
# - result_value: Result of calling `fn(elems_value)`. May contain
|
||||
# composite tensors.
|
||||
# - result_value_flat: Flattened list of tensors from result_value.
|
||||
# May contain composite tensors.
|
||||
# - result_value_batchable: Concatenation of batchable tensor lists for
|
||||
# each tensor in result_value_flat
|
||||
# (using TensorSpec._to_tensor_list).
|
||||
#
|
||||
# - Collecting and stacking output values from the while_loop:
|
||||
# - result_batchable_ta: List of TensorArrays used to stack each tensor
|
||||
# ta result_value_batchable into result_batchable.
|
||||
# - result_batchable: Stacked tensors from result_batchable_ta.
|
||||
# - result_flat: Flat list of tensors for the result, constructed from
|
||||
# results bactchable (using TensorSpec._from_tensor_list).
|
||||
# - result: Structured result value packed from results flat
|
||||
# (using nest.pack_sequence_as).
|
||||
|
||||
if fn_output_signature is None:
|
||||
fn_output_signature = dtype
|
||||
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
if isinstance(elems, sparse_tensor.SparseTensor):
|
||||
raise TypeError(
|
||||
"To perform a map on the values of a sparse tensor use either "
|
||||
" SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
|
||||
" SparseTensor(input.indices, map_fn(fn, input.values), "
|
||||
"input.dense_shape)")
|
||||
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
# Set the default number of parallel_iterations depending on graph/eager mode.
|
||||
if in_graph_mode and not parallel_iterations:
|
||||
parallel_iterations = 10
|
||||
elif not in_graph_mode and not parallel_iterations:
|
||||
parallel_iterations = 1
|
||||
|
||||
if not in_graph_mode and parallel_iterations > 1:
|
||||
elif not in_graph_mode and parallel_iterations > 1:
|
||||
logging.log_first_n(
|
||||
logging.WARN, "Setting parallel_iterations > 1 has no "
|
||||
"effect when executing eagerly. Consider calling map_fn"
|
||||
@ -171,23 +366,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
|
||||
"parallel.", 1)
|
||||
parallel_iterations = 1
|
||||
|
||||
input_is_sequence = nest.is_sequence(elems)
|
||||
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
|
||||
def input_pack(x):
|
||||
return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
|
||||
# Flatten the input tensors, and get the TypeSpec for each one.
|
||||
elems_flat = nest.flatten(elems)
|
||||
elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat]
|
||||
elems_unflatten = lambda x: nest.pack_sequence_as(elems, x)
|
||||
|
||||
if dtype is None:
|
||||
output_is_sequence = input_is_sequence
|
||||
output_flatten = input_flatten
|
||||
output_pack = input_pack
|
||||
# Flatten fn's output signature.
|
||||
if fn_output_signature is None:
|
||||
# If fn_output_signature was not specified, then assume that it matches the
|
||||
# input signature.
|
||||
result_flat_signature = [
|
||||
_most_general_compatible_type(s)._unbatch() # pylint: disable=protected-access
|
||||
for s in elems_flat_signature
|
||||
]
|
||||
result_unflatten = elems_unflatten
|
||||
else:
|
||||
output_is_sequence = nest.is_sequence(dtype)
|
||||
output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
|
||||
def output_pack(x):
|
||||
return (nest.pack_sequence_as(dtype, x)
|
||||
if output_is_sequence else x[0])
|
||||
|
||||
elems_flat = input_flatten(elems)
|
||||
result_flat_signature = [
|
||||
_dtype_to_spec(d) for d in nest.flatten(fn_output_signature)
|
||||
]
|
||||
result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x)
|
||||
|
||||
with ops.name_scope(name, "map", elems_flat):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
@ -204,42 +401,56 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
|
||||
varscope_caching_device_was_none = True
|
||||
|
||||
elems_flat = [
|
||||
ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
|
||||
ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat
|
||||
]
|
||||
|
||||
dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
|
||||
dtype_flat = output_flatten(dtype)
|
||||
|
||||
# Convert elems to tensor array. n may be known statically.
|
||||
static_shape = elems_flat[0].shape
|
||||
if static_shape.ndims is not None and static_shape.ndims < 1:
|
||||
# Check that inputs are not scalars.
|
||||
elems_static_shape = elems_flat[0].shape
|
||||
if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1:
|
||||
if len(elems_flat) == 1:
|
||||
raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
|
||||
else:
|
||||
raise ValueError(
|
||||
"elements in elems must be 1+ dimensional Tensors, not scalars"
|
||||
)
|
||||
n = (tensor_shape.dimension_value(static_shape[0])
|
||||
or array_ops.shape(elems_flat[0])[0])
|
||||
|
||||
# TensorArrays are always flat
|
||||
elems_ta = [
|
||||
tensor_array_ops.TensorArray(dtype=elem.dtype,
|
||||
size=n,
|
||||
dynamic_size=False,
|
||||
infer_shape=True)
|
||||
for elem in elems_flat]
|
||||
# Box any composite tensors into tensor lists.
|
||||
elems_batchable = _elems_flat_to_batchable(elems_flat)
|
||||
|
||||
# Find the number of iterations, n. (may be known statically.)
|
||||
n_static = tensor_shape.Dimension(
|
||||
tensor_shape.dimension_value(
|
||||
elems_batchable[0].get_shape().with_rank_at_least(1)[0]))
|
||||
for tensor in elems_batchable[1:]:
|
||||
n_static.merge_with(
|
||||
tensor_shape.Dimension(
|
||||
tensor_shape.dimension_value(
|
||||
tensor.get_shape().with_rank_at_least(1)[0])))
|
||||
n = n_static.value or array_ops.shape(elems_batchable[0])[0]
|
||||
|
||||
# Convert elems to tensor array.
|
||||
# TODO(edloper): Should we set infer_shape=False for composite tensors?
|
||||
elems_batchable_ta = [
|
||||
tensor_array_ops.TensorArray(
|
||||
dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True)
|
||||
for t in elems_batchable
|
||||
]
|
||||
# Unpack elements
|
||||
elems_ta = [
|
||||
elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]
|
||||
elems_batchable_ta = [
|
||||
ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable)
|
||||
]
|
||||
|
||||
i = constant_op.constant(0)
|
||||
|
||||
accs_ta = [
|
||||
tensor_array_ops.TensorArray(dtype=dt,
|
||||
size=n,
|
||||
dynamic_size=False,
|
||||
infer_shape=infer_shape)
|
||||
for dt in dtype_flat]
|
||||
# Prepare result tensor array.
|
||||
# TODO(edloper): Should we set infer_shape=False for composite tensors?
|
||||
result_batchable_dtype = _result_flat_signature_to_batchable_dtype(
|
||||
result_flat_signature)
|
||||
result_batchable_ta = [
|
||||
tensor_array_ops.TensorArray(
|
||||
dtype=dt, size=n, dynamic_size=False, infer_shape=infer_shape)
|
||||
for dt in result_batchable_dtype
|
||||
]
|
||||
|
||||
def compute(i, tas):
|
||||
"""The loop body of map_fn.
|
||||
@ -252,30 +463,34 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
|
||||
(i + 1, tas): the updated counter + updated TensorArrays
|
||||
|
||||
Raises:
|
||||
TypeError: if dtype and packed_fn_values structure do not match
|
||||
ValueType: if dtype and packed_fn_values lengths do not match
|
||||
TypeError: if fn_output_signature and result_value structure don't match
|
||||
ValueType: if fn_output_signature and result_value lengths don't match
|
||||
"""
|
||||
packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
|
||||
packed_fn_values = fn(packed_values)
|
||||
nest.assert_same_structure(dtype or elems, packed_fn_values)
|
||||
flat_fn_values = output_flatten(packed_fn_values)
|
||||
tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]
|
||||
elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta]
|
||||
elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable,
|
||||
elems_flat_signature)
|
||||
elems_value = elems_unflatten(elems_value_flat)
|
||||
result_value = fn(elems_value)
|
||||
nest.assert_same_structure(fn_output_signature or elems, result_value)
|
||||
result_value_flat = nest.flatten(result_value)
|
||||
result_value_batchable = _result_value_flat_to_batchable(
|
||||
result_value_flat, result_flat_signature)
|
||||
tas = [
|
||||
ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable)
|
||||
]
|
||||
return (i + 1, tas)
|
||||
|
||||
_, r_a = control_flow_ops.while_loop(
|
||||
lambda i, _: i < n, compute, (i, accs_ta),
|
||||
lambda i, _: i < n,
|
||||
compute, (i, result_batchable_ta),
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory,
|
||||
maximum_iterations=n)
|
||||
results_flat = [r.stack() for r in r_a]
|
||||
result_batchable = [r.stack() for r in r_a]
|
||||
|
||||
n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
|
||||
elems_flat[0].get_shape().with_rank_at_least(1)[0]))
|
||||
for elem in elems_flat[1:]:
|
||||
n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value(
|
||||
elem.get_shape().with_rank_at_least(1)[0])))
|
||||
for r in results_flat:
|
||||
# Update each output tensor w/ static shape info about the outer dimension.
|
||||
for r in result_batchable:
|
||||
r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
|
||||
r.get_shape()[1:]))
|
||||
|
||||
@ -284,7 +499,103 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True,
|
||||
if in_graph_mode and varscope_caching_device_was_none:
|
||||
varscope.set_caching_device(None)
|
||||
|
||||
return output_pack(results_flat)
|
||||
result_flat = _result_batchable_to_flat(result_batchable,
|
||||
result_flat_signature)
|
||||
result = result_unflatten(result_flat)
|
||||
return result
|
||||
|
||||
|
||||
def _dtype_to_spec(d):
|
||||
if not isinstance(d, type_spec.TypeSpec):
|
||||
d = tensor_spec.TensorSpec(None, d)
|
||||
return d
|
||||
|
||||
|
||||
def _most_general_compatible_type(spec):
|
||||
"""Returns the most general TypeSpec compatible with `spec`."""
|
||||
# TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
|
||||
if isinstance(spec, tensor_spec.TensorSpec):
|
||||
return tensor_spec.TensorSpec(None, spec.dtype)
|
||||
elif isinstance(spec, ragged_tensor.RaggedTensorSpec):
|
||||
# pylint: disable=protected-access
|
||||
return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
|
||||
spec._row_splits_dtype)
|
||||
elif isinstance(spec, sparse_tensor.SparseTensorSpec):
|
||||
# pylint: disable=protected-access
|
||||
return sparse_tensor.SparseTensorSpec(None, spec.dtype)
|
||||
else:
|
||||
return spec
|
||||
|
||||
|
||||
def _result_flat_signature_to_batchable_dtype(result_flat_signature):
|
||||
"""Converts result_flat_signature -> result_batchable_dtype."""
|
||||
components = []
|
||||
for spec in result_flat_signature:
|
||||
if not isinstance(spec, type_spec.BatchableTypeSpec):
|
||||
raise TypeError("map_fn can not generate %s outputs" % (spec,))
|
||||
# pylint: disable=protected-access
|
||||
components.extend([s.dtype for s in spec._flat_tensor_specs])
|
||||
return components
|
||||
|
||||
|
||||
def _elems_flat_to_batchable(elems_flat):
|
||||
"""Converts elems_flat -> elems_batchable."""
|
||||
elems_batchable = []
|
||||
for elems_tensor in elems_flat:
|
||||
spec = type_spec.type_spec_from_value(elems_tensor)
|
||||
if not isinstance(spec, type_spec.BatchableTypeSpec):
|
||||
raise TypeError("map_fn can not consume %s inputs: got %r" %
|
||||
(spec, elems_tensor))
|
||||
# pylint: disable=protected-access
|
||||
elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor))
|
||||
return elems_batchable
|
||||
|
||||
|
||||
def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature):
|
||||
"""Converts elems_value_batchable -> elems_value_flat."""
|
||||
elems_value_flat = []
|
||||
i = 0
|
||||
for spec in elems_flat_signature:
|
||||
# pylint: disable=protected-access
|
||||
spec = spec._unbatch()
|
||||
tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)]
|
||||
elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list))
|
||||
i += len(tensor_list)
|
||||
assert i == len(elems_value_batchable)
|
||||
return elems_value_flat
|
||||
|
||||
|
||||
def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
|
||||
"""Converts result_value_flat -> result_value_batchable."""
|
||||
result_value_batchable = []
|
||||
for (r_value, r_spec) in zip(result_value_flat, result_flat_signature):
|
||||
if isinstance(r_spec, tensor_spec.TensorSpec):
|
||||
result_value_batchable.append(r_value)
|
||||
else:
|
||||
if not r_spec.is_compatible_with(r_value):
|
||||
raise ValueError(
|
||||
"Error in map_fn:\n Expected `fn` to return a:\n %s\n"
|
||||
" But it returned a:\n %s\n (value=%s)\n"
|
||||
" To fix, update the `fn_output_signature` (or `dtype`) "
|
||||
"argument to `map_fn`." %
|
||||
(r_spec, type_spec.type_spec_from_value(r_value), r_value))
|
||||
result_value_batchable.extend(r_spec._to_tensor_list(r_value)) # pylint: disable=protected-access
|
||||
return result_value_batchable
|
||||
|
||||
|
||||
def _result_batchable_to_flat(result_batchable, result_flat_signature):
|
||||
"""Converts result_batchable -> result_flat."""
|
||||
result_flat = []
|
||||
i = 0
|
||||
for spec in result_flat_signature:
|
||||
# pylint: disable=protected-access
|
||||
num_tensors = len(spec._flat_tensor_specs)
|
||||
result_flat.append(
|
||||
spec._batch(None)._from_compatible_tensor_list(
|
||||
result_batchable[i:i + num_tensors]))
|
||||
i += num_tensors
|
||||
assert i == len(result_batchable)
|
||||
return result_flat
|
||||
|
||||
|
||||
@tf_export("map_fn", v1=[])
|
||||
@ -297,6 +608,7 @@ Use:
|
||||
results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""",
|
||||
warn_once=True,
|
||||
back_prop=False)
|
||||
@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype")
|
||||
def map_fn_v2(fn,
|
||||
elems,
|
||||
dtype=None,
|
||||
@ -304,122 +616,25 @@ def map_fn_v2(fn,
|
||||
back_prop=True,
|
||||
swap_memory=False,
|
||||
infer_shape=True,
|
||||
name=None):
|
||||
"""map on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
The simplest version of `map_fn` repeatedly applies the callable `fn` to a
|
||||
sequence of elements from first to last. The elements are made of the
|
||||
tensors unpacked from `elems`. `dtype` is the data type of the return
|
||||
value of `fn`. Users must provide `dtype` if it is different from
|
||||
the data type of `elems`.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[values.shape[0]] + fn(values[0]).shape`.
|
||||
|
||||
This method also allows multi-arity `elems` and output of `fn`. If `elems`
|
||||
is a (possibly nested) list or tuple of tensors, then each of these tensors
|
||||
must have a matching first (unpack) dimension. The signature of `fn` may
|
||||
match the structure of `elems`. That is, if `elems` is
|
||||
`(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
|
||||
`fn = lambda (t1, [t2, t3, [t4, t5]]):`.
|
||||
|
||||
Furthermore, `fn` may emit a different structure than its input. For example,
|
||||
`fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case,
|
||||
the `dtype` parameter is not optional: `dtype` must be a type or (possibly
|
||||
nested) tuple of types matching the output of `fn`.
|
||||
|
||||
To apply a functional operation to the nonzero elements of a SparseTensor
|
||||
one of the following methods is recommended. First, if the function is
|
||||
expressible as TensorFlow ops, use
|
||||
|
||||
```python
|
||||
result = SparseTensor(input.indices, fn(input.values), input.dense_shape)
|
||||
```
|
||||
|
||||
If, however, the function is not expressible as a TensorFlow op, then use
|
||||
|
||||
```python
|
||||
result = SparseTensor(
|
||||
input.indices, map_fn(fn, input.values), input.dense_shape)
|
||||
```
|
||||
|
||||
instead.
|
||||
|
||||
When executing eagerly, map_fn does not execute in parallel even if
|
||||
`parallel_iterations` is set to a value > 1. You can still get the
|
||||
performance benefits of running a function in parallel by using the
|
||||
`tf.function` decorator,
|
||||
|
||||
```python
|
||||
# Assume the function being used in map_fn is fn.
|
||||
# To ensure map_fn calls fn in parallel, use the tf.function decorator.
|
||||
@tf.function
|
||||
def func(tensor):
|
||||
return tf.map_fn(fn, tensor)
|
||||
```
|
||||
|
||||
Note that if you use the `tf.function` decorator, any non-TensorFlow Python
|
||||
code that you may have written in your function won't get executed. See
|
||||
[`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for
|
||||
more details. The recommendation would be to debug without `tf.function` but
|
||||
switch to it to get performance benefits of running `map_fn` in parallel.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed. It accepts one argument, which will have
|
||||
the same (possibly nested) structure as `elems`. Its output must have the
|
||||
same structure as `dtype` if one is provided, otherwise it must have the
|
||||
same structure as `elems`.
|
||||
elems: A tensor or (possibly nested) sequence of tensors, each of which will
|
||||
be unpacked along their first dimension. The nested sequence of the
|
||||
resulting slices will be applied to `fn`.
|
||||
dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure
|
||||
of Tensors differing from the structure of `elems`, then `dtype` is not
|
||||
optional and must have the same structure as the output of `fn`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run in
|
||||
parallel. When graph building, the default value is 10. While executing
|
||||
eagerly, the default value is set to 1.
|
||||
back_prop: (optional) Deprecated. False disables support for back
|
||||
propagation. Prefer using `tf.stop_gradient` instead.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
infer_shape: (optional) False disables tests for consistent output shapes.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor or (possibly nested) sequence of tensors. Each tensor packs the
|
||||
results of applying `fn` to tensors unpacked from `elems` along the first
|
||||
dimension, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable or the structure of the output of
|
||||
`fn` and `dtype` do not match, or if elems is a SparseTensor.
|
||||
ValueError: if the lengths of the output of `fn` and `dtype` do not match.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
elems = np.array([1, 2, 3, 4, 5, 6])
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
|
||||
```python
|
||||
elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
|
||||
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
|
||||
# alternate == [-1, 2, -3]
|
||||
```
|
||||
|
||||
```python
|
||||
elems = np.array([1, 2, 3])
|
||||
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
|
||||
# alternates[0] == [1, 2, 3]
|
||||
# alternates[1] == [-1, -2, -3]
|
||||
```
|
||||
"""
|
||||
name=None,
|
||||
fn_output_signature=None):
|
||||
"""Transform `elems` by applying `fn` to each element unstacked on axis 0."""
|
||||
if fn_output_signature is None:
|
||||
fn_output_signature = dtype
|
||||
return map_fn(
|
||||
fn=fn,
|
||||
elems=elems,
|
||||
dtype=dtype,
|
||||
fn_output_signature=fn_output_signature,
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory,
|
||||
infer_shape=infer_shape,
|
||||
name=name)
|
||||
|
||||
|
||||
# Docstring for v2 is the same as v1, except that back_prop is deprecated.
|
||||
map_fn_v2.__doc__ = re.sub(
|
||||
r"( back_prop: \(optional\) )(.*)",
|
||||
r"\1Deprecated: prefer using `tf.stop_gradient` instead. \2",
|
||||
map_fn.__doc__)
|
||||
assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__
|
||||
|
@ -46,13 +46,16 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
dict(
|
||||
fn=mo.reduce_mean,
|
||||
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
||||
elems_dtype=dtypes.int32,
|
||||
expected_output=[2, 4, 6],
|
||||
result_dtype=dtypes.int32,
|
||||
),
|
||||
dict(
|
||||
fn=string_ops.reduce_join,
|
||||
elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']],
|
||||
expected_output=[b'foobarbaz', b'a', b'bc'],
|
||||
dtype=dtypes.string,
|
||||
elems_dtype=dtypes.string,
|
||||
result_dtype=dtypes.string,
|
||||
),
|
||||
# [d1, (d2)] -> [d1, 2]
|
||||
dict(
|
||||
@ -60,7 +63,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
# fn=self.stack_mean_and_sum,
|
||||
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
||||
expected_output=[[2, 6], [4.5, 9], [6.5, 13]],
|
||||
dtype=dtypes.float32,
|
||||
elems_dtype=dtypes.float32,
|
||||
result_dtype=dtypes.float32,
|
||||
expected_ragged_rank=0,
|
||||
),
|
||||
# [d1, (d2)] -> [d1, (d2)]
|
||||
@ -68,7 +72,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
fn=lambda x: x + np.int64(1),
|
||||
elems=[[1, 2, 3], [4, 5], [6, 7]],
|
||||
expected_output=[[2, 3, 4], [5, 6], [7, 8]],
|
||||
dtype=dtypes.int64,
|
||||
elems_dtype=dtypes.int64,
|
||||
result_dtype=ragged_tensor.RaggedTensorType(
|
||||
dtype=dtypes.int64, ragged_rank=1),
|
||||
),
|
||||
@ -157,11 +161,11 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
expected_ragged_rank=None,
|
||||
result_ragged_rank=None,
|
||||
elems_ragged_rank=None,
|
||||
dtype=dtypes.int64,
|
||||
elems_dtype=dtypes.int64,
|
||||
result_dtype=None,
|
||||
infer_shape=False,
|
||||
infer_shape=True,
|
||||
):
|
||||
elems = ragged_factory_ops.constant(elems, dtype, elems_ragged_rank)
|
||||
elems = ragged_factory_ops.constant(elems, elems_dtype, elems_ragged_rank)
|
||||
output = ragged_map_ops.map_fn(
|
||||
fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape)
|
||||
|
||||
@ -260,8 +264,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
def testMismatchRaggedRank(self):
|
||||
elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]])
|
||||
fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0)
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, r'The declared ragged rank (23) mismatches the result (1)'):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
|
||||
_ = ragged_map_ops.map_fn(
|
||||
fn,
|
||||
elems,
|
||||
@ -271,8 +275,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
def testMismatchRaggedRank2(self):
|
||||
elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]])
|
||||
fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0])
|
||||
with self.assertRaisesWithLiteralMatch(
|
||||
ValueError, r'The declared ragged rank (10) mismatches the result (2)'):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'(?s)Expected `fn` to return.*But it returned.*'):
|
||||
_ = ragged_map_ops.map_fn(
|
||||
fn,
|
||||
elems,
|
||||
|
@ -17,22 +17,14 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import collections
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.ops.ragged import ragged_config
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
|
||||
map_fn_lib = LazyLoader(
|
||||
"map_fn_lib", globals(),
|
||||
"tensorflow.python.ops.map_fn")
|
||||
|
||||
|
||||
def map_fn(fn,
|
||||
@ -166,298 +158,25 @@ def map_fn(fn,
|
||||
# out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]])
|
||||
```
|
||||
"""
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
if isinstance(elems, sparse_tensor.SparseTensor):
|
||||
raise TypeError(
|
||||
"To perform a map on the values of a sparse tensor use either "
|
||||
" SparseTensor(input.indices, fn(input.values), input.dense_shape) or "
|
||||
" SparseTensor(input.indices, map_fn(fn, input.values), "
|
||||
"input.dense_shape)")
|
||||
|
||||
in_graph_mode = not context.executing_eagerly()
|
||||
# Set the default number of parallel_iterations depending on graph/eager mode.
|
||||
if in_graph_mode and not parallel_iterations:
|
||||
parallel_iterations = 10
|
||||
elif not in_graph_mode and not parallel_iterations:
|
||||
parallel_iterations = 1
|
||||
|
||||
if not in_graph_mode and parallel_iterations > 1:
|
||||
logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no "
|
||||
"effect when executing eagerly. Consider calling map_fn"
|
||||
" with tf.contrib.eager.defun to execute fn in "
|
||||
"parallel.", 1)
|
||||
parallel_iterations = 1
|
||||
|
||||
input_is_sequence = nest.is_sequence(elems)
|
||||
input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
|
||||
|
||||
def input_pack(x):
|
||||
return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
|
||||
|
||||
elems_flat = input_flatten(elems)
|
||||
elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat)
|
||||
|
||||
with ops.name_scope(name, "map", elems_flat):
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
if in_graph_mode:
|
||||
# Any get_variable calls in fn will cache the first call locally
|
||||
# and not issue repeated network I/O requests for each iteration.
|
||||
varscope = vs.get_variable_scope()
|
||||
varscope_caching_device_was_none = False
|
||||
if varscope.caching_device is None:
|
||||
# TODO(ebrevdo): Change to using colocate_with here and in other
|
||||
# methods.
|
||||
varscope.set_caching_device(lambda op: op.device)
|
||||
varscope_caching_device_was_none = True
|
||||
|
||||
elems_flat = [
|
||||
ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem")
|
||||
for elem in elems_flat
|
||||
]
|
||||
|
||||
# We can either infer the output, or we can assume that it will be the same
|
||||
# as the input structure.
|
||||
dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
|
||||
|
||||
# Find the number of iterations, n may be known statically.
|
||||
if isinstance(elems_flat[0], ragged_tensor.RaggedTensor):
|
||||
n = elems_flat[0].nrows(out_type=dtypes.int32)
|
||||
else:
|
||||
static_shape = elems_flat[0].shape
|
||||
if static_shape.ndims is not None and static_shape.ndims < 1:
|
||||
if len(elems_flat) == 1:
|
||||
raise ValueError(
|
||||
"elems must be a 1+ dimensional Tensor, not a scalar")
|
||||
else:
|
||||
raise ValueError(
|
||||
"elements in elems must be 1+ dimensional Tensors, not scalars")
|
||||
n = (tensor_shape.dimension_value(static_shape[0]) or
|
||||
array_ops.shape(elems_flat[0])[0])
|
||||
|
||||
n = math_ops.cast(n, dtype=dtypes.int32)
|
||||
# Create a flat list of TAs.
|
||||
|
||||
# Flatten the dtype structure to a list.
|
||||
dtype_flat = nest.flatten(dtype)
|
||||
|
||||
# decompose to components
|
||||
dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat]
|
||||
dtype_components_flat = nest.flatten(dtype_components)
|
||||
|
||||
# Create TensorArrays.
|
||||
accs_ta = [
|
||||
tensor_array_ops.TensorArray(
|
||||
dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n)
|
||||
for t in dtype_components_flat
|
||||
]
|
||||
|
||||
i = constant_op.constant(0, dtype=dtypes.int32)
|
||||
|
||||
def compute(i, tas):
|
||||
"""The loop body of map_fn.
|
||||
|
||||
Args:
|
||||
i: the loop counter
|
||||
tas: the flat TensorArray accumulator list
|
||||
|
||||
Returns:
|
||||
(i + 1, tas): the updated counter + updated TensorArrays
|
||||
|
||||
Raises:
|
||||
TypeError: if dtype and packed_fn_values structure do not match
|
||||
ValueType: if dtype and packed_fn_values lengths do not match
|
||||
"""
|
||||
# Get Tensors or RaggedTensors sliced at i, then pack it back to the
|
||||
# original structure.
|
||||
packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat])
|
||||
packed_fn_values = fn(packed_values)
|
||||
|
||||
# Check that the structure of the output matches what was declared or
|
||||
# inferred.
|
||||
# nest.assert_same_structure(dtype or elems, packed_fn_values)
|
||||
|
||||
# Flatten and decompose to a list of Tensors
|
||||
flat_fn_values = nest.flatten(packed_fn_values)
|
||||
|
||||
# If we declared that we are expecting a RaggedTensor output, but we get a
|
||||
# Tensor output. We should try to convert it to a RaggedTensor.
|
||||
flat_fn_composite_tensors = list(
|
||||
_convert_declared(flat_fn_values, dtype_flat))
|
||||
|
||||
flat_fn_components = [
|
||||
_maybe_decompose_tensor(t) for t in flat_fn_composite_tensors
|
||||
]
|
||||
flat_fn_tensors = nest.flatten(flat_fn_components)
|
||||
|
||||
# Write to TAs.
|
||||
tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)]
|
||||
|
||||
return (i + 1, tas)
|
||||
|
||||
_, r_a = control_flow_ops.while_loop(
|
||||
lambda i, _: i < n, compute, (i, accs_ta),
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory,
|
||||
maximum_iterations=n)
|
||||
|
||||
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
|
||||
# supported in Eager
|
||||
if in_graph_mode and varscope_caching_device_was_none:
|
||||
varscope.set_caching_device(None)
|
||||
|
||||
# Pack back into a list of components
|
||||
results_as_components = nest.pack_sequence_as(dtype_components, r_a)
|
||||
|
||||
# Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs.
|
||||
def _stack_or_concat(e):
|
||||
if isinstance(e, _RaggedTensorComponents):
|
||||
return _concat_ragged_tensor_components(e)
|
||||
else:
|
||||
result = e.stack()
|
||||
return result
|
||||
|
||||
results_flat_components = [
|
||||
_stack_or_concat(e) for e in results_as_components
|
||||
]
|
||||
|
||||
results_packed = [
|
||||
_maybe_recompose_tensor(c) for c in results_flat_components
|
||||
]
|
||||
results_packed = nest.pack_sequence_as(dtype, results_packed)
|
||||
return results_packed
|
||||
if dtype is None:
|
||||
dtype = nest.map_structure(lambda e: e.dtype, elems)
|
||||
dtype = nest.map_structure(_ragged_type_to_spec, dtype)
|
||||
return map_fn_lib.map_fn(fn,
|
||||
elems,
|
||||
dtype,
|
||||
parallel_iterations,
|
||||
back_prop,
|
||||
swap_memory,
|
||||
infer_shape,
|
||||
name)
|
||||
|
||||
|
||||
class _RaggedTensorComponents(
|
||||
collections.namedtuple(
|
||||
"_RaggedTensorComponents",
|
||||
["flat_values", "nested_row_lengths", "outer_row_length"])):
|
||||
"""A namedtuple of components which represent a `RaggedTensor`.
|
||||
|
||||
_RaggedTensorComponents is a list of components which can be used to create a
|
||||
`RaggedTensor`. Use this class to represent a `RaggedTensor` in situations
|
||||
where nest.flatten and nest.pack_sequence_as should decompose ragged tensors
|
||||
into their components..
|
||||
|
||||
The following are a list of components for a `RaggedTensor`:
|
||||
|
||||
flat_values: The flat and inner values of a RaggedTensor. This could be
|
||||
a `Tensor`, a `TensorArray`, or a data type.
|
||||
nested_row_lengths: a tuple containing the row lengths of each rank. The
|
||||
elements of the tuple could be `Tensor`s or `TensorArray`s.
|
||||
outer_row_length: a `Tensor` or `TensorArray` containing the row length of the
|
||||
`RaggedTensor`'s outermost dimension.
|
||||
|
||||
See `RaggedTensor` for more details of the use of each component.
|
||||
"""
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
def _concat_ragged_tensor_components(rt_ta):
|
||||
flat_values = rt_ta.flat_values.concat()
|
||||
nested_row_lengths = tuple(
|
||||
row_lengths_ta.concat() for row_lengths_ta in rt_ta.nested_row_lengths)
|
||||
outer_row_length = rt_ta.outer_row_length.concat()
|
||||
return _RaggedTensorComponents(
|
||||
flat_values=flat_values,
|
||||
nested_row_lengths=nested_row_lengths,
|
||||
outer_row_length=outer_row_length)
|
||||
|
||||
|
||||
def _maybe_decompose_tensor(rt):
|
||||
"""Decompose tensors to their composite tensors."""
|
||||
if not isinstance(rt, ragged_tensor.RaggedTensor):
|
||||
return rt
|
||||
|
||||
# The three component pieces we need:
|
||||
# - inner values
|
||||
flat_values = rt.flat_values
|
||||
|
||||
# - row_splits of the RT
|
||||
splits = rt.nested_row_splits
|
||||
nested_row_lengths = tuple(split[1:] - split[:-1] for split in splits)
|
||||
|
||||
# - outer row length
|
||||
outer_row_length = array_ops.expand_dims(rt.nrows(), axis=0)
|
||||
|
||||
return _RaggedTensorComponents(
|
||||
flat_values=flat_values,
|
||||
nested_row_lengths=nested_row_lengths,
|
||||
outer_row_length=outer_row_length,
|
||||
)
|
||||
|
||||
|
||||
def _maybe_recompose_tensor(t):
|
||||
"""Reconstructs a _RaggedTensorComponents into a RaggedTensor."""
|
||||
if not isinstance(t, _RaggedTensorComponents):
|
||||
return t
|
||||
|
||||
values = t.flat_values
|
||||
nested_row_lengths = tuple(t.nested_row_lengths)
|
||||
for nested_row_length in reversed(nested_row_lengths):
|
||||
values = ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
values, nested_row_length, validate=False)
|
||||
return ragged_tensor.RaggedTensor.from_row_lengths(values, t.outer_row_length,
|
||||
validate=False)
|
||||
|
||||
|
||||
def _maybe_decompose_dtype(d):
|
||||
"""Decompose dtypes into composite tensors (if necessary)."""
|
||||
if not isinstance(d, ragged_tensor.RaggedTensorType):
|
||||
return d
|
||||
|
||||
result = _RaggedTensorComponents(
|
||||
flat_values=d.dtype,
|
||||
nested_row_lengths=tuple(
|
||||
d.row_splits_dtype for i in range(d.ragged_rank - 1)),
|
||||
outer_row_length=d.row_splits_dtype,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _convert_declared(fn_output_flat, output_declared):
|
||||
"""Convert outputs which are `Tensor`s into `_RaggedTensorComponents`."""
|
||||
for current, declared in zip(fn_output_flat, output_declared):
|
||||
if isinstance(declared, ragged_tensor.RaggedTensorType):
|
||||
yield _convert_declared_ragged(current, declared)
|
||||
else:
|
||||
yield current
|
||||
|
||||
|
||||
def _convert_declared_ragged(current, declared):
|
||||
"""Converts an output with RaggedTensorType into a _RaggedTensorComponents."""
|
||||
# Check that the ragged ranks match up.
|
||||
# + 1 to account for the rank of the outermost dimension.
|
||||
current_ragged_rank = getattr(current, "ragged_rank", 0)
|
||||
if declared.ragged_rank != current_ragged_rank + 1:
|
||||
raise ValueError(
|
||||
"The declared ragged rank (%d) mismatches the result (%d)" %
|
||||
(declared.ragged_rank, current_ragged_rank + 1))
|
||||
|
||||
# Check that dtypes match up.
|
||||
if declared.dtype != current.dtype:
|
||||
raise ValueError(
|
||||
"The declared dtype (%s) mismatches the result (%s)" %
|
||||
(declared.dtype, current.dtype))
|
||||
if (isinstance(current, ragged_tensor.RaggedTensor) and
|
||||
declared.row_splits_dtype != current.row_splits.dtype):
|
||||
if not ragged_config.auto_cast_partition_dtype():
|
||||
raise ValueError(
|
||||
"The declared row_splits dtype (%s) mismatches the result (%s)."
|
||||
" Use RaggedTensor.with_row_splits_dtype to convert it."
|
||||
% (declared.row_splits_dtype, current.row_splits.dtype))
|
||||
current = current.with_row_splits_dtype(declared.row_splits_dtype)
|
||||
|
||||
if isinstance(current, ragged_tensor.RaggedTensor):
|
||||
return current
|
||||
def _ragged_type_to_spec(t):
|
||||
if isinstance(t, ragged_tensor.RaggedTensorType):
|
||||
# Note: need to adjust ragged_rank by 1, since RaggedTensorSpec gives the
|
||||
# type for the mapped `fn` output, but RaggedTensorType gives the type for
|
||||
# the result of stacking the mapped `fn` outputs.
|
||||
return ragged_tensor.RaggedTensorSpec(
|
||||
None, t.dtype, t.ragged_rank - 1, t.row_splits_dtype)
|
||||
else:
|
||||
nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0]
|
||||
row_length = array_ops.expand_dims(nrows, axis=0)
|
||||
return _RaggedTensorComponents(
|
||||
flat_values=current,
|
||||
nested_row_lengths=(),
|
||||
outer_row_length=row_length)
|
||||
|
||||
return t
|
||||
|
@ -2330,6 +2330,10 @@ class RaggedTensorType(object):
|
||||
ragged_rank = property(lambda self: self._ragged_rank)
|
||||
row_splits_dtype = property(lambda self: self._row_splits_dtype)
|
||||
|
||||
def __repr__(self):
|
||||
return "RaggedTensorType(%r, %r, %r)" % (
|
||||
self.dtype, self.ragged_rank, self.row_splits_dtype)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Helper Functions
|
||||
|
@ -1626,7 +1626,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "map_fn"
|
||||
argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "matching_files"
|
||||
|
@ -786,7 +786,7 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "map_fn"
|
||||
argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "matmul"
|
||||
|
Loading…
Reference in New Issue
Block a user