diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index 96810805c18..516f427ad08 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -108,31 +108,29 @@ def map_fn(fn, `fn_output_signature` can be specified using any of the following: - * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) - * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) - * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`) - * A (possibly nested) tuple, list, or dict containing the above types. + * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) + * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) + * A `tf.SparseTensorSpec` (to describe a `tf.sparse.SparseTensor`) + * A (possibly nested) tuple, list, or dict containing the above types. #### RaggedTensors `map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular: - * If `elems` is a `RaggedTensor`, then `fn` will be called with each - row of that ragged tensor. + * If `elems` is a `RaggedTensor`, then `fn` will be called with each + row of that ragged tensor. + * If `elems` has only one ragged dimension, then the values passed to + `fn` will be `tf.Tensor`s. + * If `elems` has multiple ragged dimensions, then the values passed to + `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension. - * If `elems` has only one ragged dimension, then the values passed to - `fn` will be `tf.Tensor`s. - * If `elems` has multiple ragged dimensions, then the values passed to - `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension. - - * If the result of `map_fn` should be a `RaggedTensor`, then use a - `tf.RaggedTensorSpec` to specify `fn_output_signature`. - - * If `fn` returns `tf.Tensor`s with varying sizes, then use a - `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a - single ragged tensor (which will have ragged_rank=1). - * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec` - with the same `ragged_rank`. + * If the result of `map_fn` should be a `RaggedTensor`, then use a + `tf.RaggedTensorSpec` to specify `fn_output_signature`. + * If `fn` returns `tf.Tensor`s with varying sizes, then use a + `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a + single ragged tensor (which will have ragged_rank=1). + * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec` + with the same `ragged_rank`. >>> # Example: RaggedTensor input >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]]) @@ -150,10 +148,10 @@ def map_fn(fn, *rows* of a `RaggedTensor`. If you wish to map a function over the individual values, then you should use: - * `tf.ragged.map_flat_values(fn, rt)` - (if fn is expressible as TensorFlow ops) - * `rt.with_flat_values(map_fn(fn, rt.flat_values))` - (otherwise) + * `tf.ragged.map_flat_values(fn, rt)` + (if fn is expressible as TensorFlow ops) + * `rt.with_flat_values(map_fn(fn, rt.flat_values))` + (otherwise) E.g.: @@ -165,14 +163,14 @@ def map_fn(fn, `map_fn` supports `tf.sparse.SparseTensor` inputs and outputs. In particular: - * If `elems` is a `SparseTensor`, then `fn` will be called with each row - of that sparse tensor. In particular, the value passed to `fn` will be a - `tf.sparse.SparseTensor` with one fewer dimension than `elems`. + * If `elems` is a `SparseTensor`, then `fn` will be called with each row + of that sparse tensor. In particular, the value passed to `fn` will be a + `tf.sparse.SparseTensor` with one fewer dimension than `elems`. - * If the result of `map_fn` should be a `SparseTensor`, then use a - `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual - `SparseTensor`s returned by `fn` will be stacked into a single - `SparseTensor` with one more dimension. + * If the result of `map_fn` should be a `SparseTensor`, then use a + `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual + `SparseTensor`s returned by `fn` will be stacked into a single + `SparseTensor` with one more dimension. >>> # Example: SparseTensor input >>> st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4]) @@ -195,15 +193,15 @@ def map_fn(fn, *rows* of a `SparseTensor`. If you wish to map a function over the nonzero values, then you should use: - * If the function is expressible as TensorFlow ops, use: - ```python - tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape) - ``` - * Otherwise, use: - ```python - tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), - st.dense_shape) - ``` + * If the function is expressible as TensorFlow ops, use: + ```python + tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape) + ``` + * Otherwise, use: + ```python + tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), + st.dense_shape) + ``` #### `map_fn` vs. vectorized operations @@ -215,14 +213,14 @@ def map_fn(fn, `map_fn` should typically only be used if one of the following is true: - * It is difficult or expensive to express the desired transform with - vectorized operations. - * `fn` creates large intermediate values, so an equivalent vectorized - transform would take too much memory. - * Processing elements in parallel is more efficient than an equivalent - vectorized transform. - * Efficiency of the transform is not critical, and using `map_fn` is - more readable. + * It is difficult or expensive to express the desired transform with + vectorized operations. + * `fn` creates large intermediate values, so an equivalent vectorized + transform would take too much memory. + * Processing elements in parallel is more efficient than an equivalent + vectorized transform. + * Efficiency of the transform is not critical, and using `map_fn` is + more readable. E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)` across `elems` could be rewritten more efficiently using vectorized ops: @@ -255,7 +253,7 @@ def map_fn(fn, [2, 3, 4]], dtype=int32)> - Note that if you use the `tf.function` decorator, any non-TensorFlow Python + Note: if you use the `tf.function` decorator, any non-TensorFlow Python code that you may have written in your function won't get executed. See `tf.function` for more details. The recommendation would be to debug without `tf.function` but switch to it to get performance benefits of running `map_fn`