diff --git a/tensorflow/python/kernel_tests/map_fn_test.py b/tensorflow/python/kernel_tests/map_fn_test.py index d2b1d433c78..1e10d689886 100644 --- a/tensorflow/python/kernel_tests/map_fn_test.py +++ b/tensorflow/python/kernel_tests/map_fn_test.py @@ -69,13 +69,14 @@ class MapFnTest(test.TestCase): def testMapSparseTensor(self): with self.cached_session(): - with self.assertRaises(TypeError): - map_fn.map_fn( - lambda x: x, - sparse_tensor.SparseTensor( - indices=[[0, 0], [0, 1], [1, 0]], - values=constant_op.constant([0, 1, 2]), - dense_shape=[2, 2])) + st = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=constant_op.constant([0, 1, 2]), + dense_shape=[2, 2]) + result = map_fn.map_fn(lambda x: x, st) + self.assertAllEqual(result.indices, st.indices) + self.assertAllEqual(result.values, st.values) + self.assertAllEqual(result.dense_shape, st.dense_shape) @test_util.run_in_graph_and_eager_modes def testMapOverScalarErrors(self): diff --git a/tensorflow/python/ops/map_fn.py b/tensorflow/python/ops/map_fn.py index 457ab309435..7438e584227 100644 --- a/tensorflow/python/ops/map_fn.py +++ b/tensorflow/python/ops/map_fn.py @@ -20,15 +20,20 @@ from __future__ import division from __future__ import print_function +import re + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -36,134 +41,324 @@ from tensorflow.python.util.tf_export import tf_export @tf_export(v1=["map_fn"]) -def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, - swap_memory=False, infer_shape=True, name=None): - """map on the list of tensors unpacked from `elems` on dimension 0. +@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype") +def map_fn(fn, + elems, + dtype=None, + parallel_iterations=None, + back_prop=True, + swap_memory=False, + infer_shape=True, + name=None, + fn_output_signature=None): + """Transforms `elems` by applying `fn` to each element unstacked on axis 0. - The simplest version of `map_fn` repeatedly applies the callable `fn` to a - sequence of elements from first to last. The elements are made of the - tensors unpacked from `elems`. `dtype` is the data type of the return - value of `fn`. Users must provide `dtype` if it is different from - the data type of `elems`. + `map_fn` unstacks `elems` on axis 0 to obtain a sequence of elements; + calls `fn` to transform each element; and then stacks the transformed + values back together. - Suppose that `elems` is unpacked into `values`, a list of tensors. The shape - of the result tensor is `[values.shape[0]] + fn(values[0]).shape`. + #### Mapping functions with single-Tensor inputs and outputs - This method also allows multi-arity `elems` and output of `fn`. If `elems` - is a (possibly nested) list or tuple of tensors, then each of these tensors - must have a matching first (unpack) dimension. The signature of `fn` may - match the structure of `elems`. That is, if `elems` is - `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: - `fn = lambda (t1, [t2, t3, [t4, t5]]):`. + If `elems` is a single tensor and `fn`'s signature is `tf.Tensor->tf.Tensor`, + then `map_fn(fn, elems)` is equivalent to + `tf.stack([fn(elem) for elem in tf.unstack(elems)])`. E.g.: - Furthermore, `fn` may emit a different structure than its input. For example, - `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case, - the `dtype` parameter is not optional: `dtype` must be a type or (possibly - nested) tuple of types matching the output of `fn`. + >>> tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2])) + - To apply a functional operation to the nonzero elements of a SparseTensor - one of the following methods is recommended. First, if the function is - expressible as TensorFlow ops, use + `map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape`. - ```python - result = SparseTensor(input.indices, fn(input.values), input.dense_shape) - ``` + #### Mapping functions with multi-arity inputs and outputs - If, however, the function is not expressible as a TensorFlow op, then use + `map_fn` also supports functions with multi-arity inputs and outputs: - ```python - result = SparseTensor( - input.indices, map_fn(fn, input.values), input.dense_shape) - ``` + * If `elems` is a tuple (or nested structure) of tensors, then those tensors + must all have the same outer-dimension size (`num_elems`); and `fn` is + used to transform each tuple (or structure) of corresponding slices from + `elems`. E.g., if `elems` is a tuple `(t1, t2, t3)`, then `fn` is used to + transform each tuple of slices `(t1[i], t2[i], t3[i])` + (where `0 <= i < num_elems`). - instead. + * If `fn` returns a tuple (or nested structure) of tensors, then the + result is formed by stacking corresponding elements from those structures. - When executing eagerly, map_fn does not execute in parallel even if + #### Specifying `fn`'s output signature + + If `fn`'s input and output signatures are different, then the output + signature must be specified using `fn_output_signature`. (The input and + output signatures are differ if their structures, dtypes, or tensor types do + not match). E.g.: + + >>> tf.map_fn(fn=tf.strings.length, # input & output have different dtypes + ... elems=tf.constant(["hello", "moon"]), + ... fn_output_signature=tf.int32) + + >>> tf.map_fn(fn=tf.strings.join, # input & output have different structures + ... elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])], + ... fn_output_signature=tf.string) + + + `fn_output_signature` can be specified using any of the following: + + * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) + * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) + * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`) + * A (possibly nested) tuple, list, or dict containing the above types. + + #### RaggedTensors + + `map_fn` supports `tf.RaggedTensor` inputs and outputs. In particular: + + * If `elems` is a `RaggedTensor`, then `fn` will be called with each + row of that ragged tensor. + + * If `elems` has only one ragged dimension, then the values passed to + `fn` will be `tf.Tensor`s. + * If `elems` has multiple ragged dimensions, then the values passed to + `fn` will be `tf.RaggedTensor`s with one fewer ragged dimension. + + * If the result of `map_fn` should be a `RaggedTensor`, then use a + `tf.RaggedTensorSpec` to specify `fn_output_signature`. + + * If `fn` returns `tf.Tensor`s with varying sizes, then use a + `tf.RaggedTensorSpec` with `ragged_rank=0` to combine them into a + single ragged tensor (which will have ragged_rank=1). + * If `fn` returns `tf.RaggedTensor`s, then use a `tf.RaggedTensorSpec` + with the same `ragged_rank`. + + >>> # Example: RaggedTensor input + >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]]) + >>> tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32) + + + >>> # Example: RaggedTensor output + >>> elems = tf.constant([3, 5, 0, 2]) + >>> tf.map_fn(tf.range, elems, + ... fn_output_signature=tf.RaggedTensorSpec(shape=[None], + ... dtype=tf.int32)) + + + Note: `map_fn` should only be used if you need to map a function over the + *rows* of a `RaggedTensor`. If you wish to map a function over the + individual values, then you should use: + + * `tf.ragged.map_flat_values(fn, rt)` + (if fn is expressible as TensorFlow ops) + * `rt.with_flat_values(map_fn(fn, rt.flat_values))` + (otherwise) + + E.g.: + + >>> rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]]) + >>> tf.ragged.map_flat_values(lambda x: x + 2, rt) + + + #### SparseTensors + + `map_fn` supports `tf.SparseTensor` inputs and outputs. In particular: + + * If `elems` is a `SparseTensor`, then `fn` will be called with each row + of that sparse tensor. In particular, the value passed to `fn` will be a + `tf.SparseTensor` with one fewer dimension than `elems`. + + * If the result of `map_fn` should be a `SparseTensor`, then use a + `tf.SparseTensorSpec` to specify `fn_output_signature`. The individual + `SparseTensor`s returned by `fn` will be stacked into a single + `SparseTensor` with one more dimension. + + >>> # Example: SparseTensor input + >>> st = tf.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4]) + >>> tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32) + + + >>> # Example: SparseTensor output + >>> tf.sparse.to_dense( + ... tf.map_fn(tf.sparse.eye, tf.constant([2, 3]), + ... fn_output_signature=tf.SparseTensorSpec(None, tf.float32))) + + + Note: `map_fn` should only be used if you need to map a function over the + *rows* of a `SparseTensor`. If you wish to map a function over the nonzero + values, then you should use: + + * `tf.SparseTensor(st.indices, fn(st.values), st.dense_shape)` + (if the function is expressible as TensorFlow ops) + * `tf.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)` + (otherwise). + + #### `map_fn` vs. vectorized operations + + `map_fn` will apply the operations used by `fn` to each element of `elems`, + resulting in `O(elems.shape[0])` total operations. This is somewhat + mitigated by the fact that `map_fn` can process elements in parallel. + However, a transform expressed using `map_fn` is still typically less + efficient than an equivalent transform expressed using vectorized operations. + + `map_fn` should typically only be used if one of the following is true: + + * It is difficult or expensive to express the desired transform with + vectorized operations. + * `fn` creates large intermediate values, so an equivalent vectorized + transform would take too much memory. + * Processing elements in parallel is more efficient than an equivalent + vectorized transform. + * Efficiency of the transform is not critical, and using `map_fn` is + more readable. + + E.g., the example given above that maps `fn=lambda t: tf.range(t, t + 3)` + across `elems` could be rewritten more efficiently using vectorized ops: + + >>> elems = tf.constant([3, 5, 2]) + >>> tf.range(3) + tf.expand_dims(elems, 1) + + + In some cases, `tf.vectorized_map` can be used to automatically convert a + function to a vectorized eqivalent. + + #### Eager execution + + When executing eagerly, `map_fn` does not execute in parallel even if `parallel_iterations` is set to a value > 1. You can still get the performance benefits of running a function in parallel by using the - `tf.function` decorator, + `tf.function` decorator: + + >>> fn=lambda t: tf.range(t, t + 3) + >>> @tf.function + ... def func(elems): + ... return tf.map_fn(fn, elems, parallel_iterations=3) + >>> func(tf.constant([3, 5, 2])) + - ```python - # Assume the function being used in map_fn is fn. - # To ensure map_fn calls fn in parallel, use the tf.function decorator. - @tf.function - def func(tensor): - return tf.map_fn(fn, tensor) - ``` Note that if you use the `tf.function` decorator, any non-TensorFlow Python code that you may have written in your function won't get executed. See - [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for - more details. The recommendation would be to debug without `tf.function` but - switch to it to get performance benefits of running `map_fn` in parallel. + `tf.function` for more details. The recommendation would be to debug without + `tf.function` but switch to it to get performance benefits of running `map_fn` + in parallel. Args: - fn: The callable to be performed. It accepts one argument, which will - have the same (possibly nested) structure as `elems`. Its output - must have the same structure as `dtype` if one is provided, otherwise - it must have the same structure as `elems`. - elems: A tensor or (possibly nested) sequence of tensors, each of which - will be unpacked along their first dimension. The nested sequence - of the resulting slices will be applied to `fn`. - dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure - of Tensors differing from the structure of `elems`, then `dtype` is not - optional and must have the same structure as the output of `fn`. - parallel_iterations: (optional) The number of iterations allowed to run - in parallel. When graph building, the default value is 10. While executing + fn: The callable to be performed. It accepts one argument, which will have + the same (possibly nested) structure as `elems`. Its output must have the + same structure as `fn_output_signature` if one is provided; otherwise it + must have the same structure as `elems`. + elems: A tensor or (possibly nested) sequence of tensors, each of which will + be unstacked along their first dimension. `fn` will be applied to the + nested sequence of the resulting slices. `elems` may include ragged and + sparse tensors. + dtype: Deprecated: Equivalent to `fn_output_signature`. + parallel_iterations: (optional) The number of iterations allowed to run in + parallel. When graph building, the default value is 10. While executing eagerly, the default value is set to 1. - back_prop: (optional) True enables support for back propagation. + back_prop: (optional) False disables support for back propagation. swap_memory: (optional) True enables GPU-CPU memory swapping. infer_shape: (optional) False disables tests for consistent output shapes. name: (optional) Name prefix for the returned tensors. + fn_output_signature: The output signature of `fn`. Must be specified if + `fn`'s input and output signatures are different (i.e., if their + structures, dtypes, or tensor types do not match). + `fn_output_signature` can be specified using any of the following: + + * A `tf.DType` or `tf.TensorSpec` (to describe a `tf.Tensor`) + * A `tf.RaggedTensorSpec` (to describe a `tf.RaggedTensor`) + * A `tf.SparseTensorSpec` (to describe a `tf.SparseTensor`) + * A (possibly nested) tuple, list, or dict containing the above types. Returns: - A tensor or (possibly nested) sequence of tensors. Each tensor packs the - results of applying `fn` to tensors unpacked from `elems` along the first - dimension, from first to last. + A tensor or (possibly nested) sequence of tensors. Each tensor stacks the + results of applying `fn` to tensors unstacked from `elems` along the first + dimension, from first to last. The result may include ragged and sparse + tensors. Raises: TypeError: if `fn` is not callable or the structure of the output of - `fn` and `dtype` do not match, or if elems is a SparseTensor. - ValueError: if the lengths of the output of `fn` and `dtype` do not match. + `fn` and `fn_output_signature` do not match. + ValueError: if the lengths of the output of `fn` and `fn_output_signature` + do not match. Examples: - ```python - elems = np.array([1, 2, 3, 4, 5, 6]) - squares = map_fn(lambda x: x * x, elems) - # squares == [1, 4, 9, 16, 25, 36] - ``` - ```python - elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) - alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64) - # alternate == [-1, 2, -3] - ``` + >>> elems = np.array([1, 2, 3, 4, 5, 6]) + >>> tf.map_fn(lambda x: x * x, elems) + - ```python - elems = np.array([1, 2, 3]) - alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64)) - # alternates[0] == [1, 2, 3] - # alternates[1] == [-1, -2, -3] - ``` + >>> elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) + >>> tf.map_fn(lambda x: x[0] * x[1], elems, fn_output_signature=tf.int64) + + + >>> elems = np.array([1, 2, 3]) + >>> tf.map_fn(lambda x: (x, -x), elems, + ... fn_output_signature=(tf.int64, tf.int64)) + (, + ) """ + # This function uses a `while_loop` to call `fn` on each value of the input + # tensor(s) (unstacked on dimension 0). The following sequence of variables + # are used to transform the input tensor(s) (`elems`) into the output + # tensor(s) (`result`): + # + # - Preparing and unstacking input values for the while_loop: + # - elems: The input tensor(s) to map_fn. May include composite tensors. + # - elems_flat: Flattened list of tensors from elems (using nest.flatten) + # May include composite tensors. + # - elems_batchable: Concatenation of "batchable tensor lists" for each + # tensor in elems_flat. This "boxes" composite tensors + # into sliceable tf.Tensor objects. For more info see: + # TensorSpec._to_batched_tensor_list + # - elems_batchable_ta: List of TensorArrays used to unstack each Tensor + # in elems_batchable into elems_value_batchable. + # + # - Calling `fn` on each unstacked value in the body of the while_loop: + # - elems_value_batchable: Single unstacked value from elems_batchable. + # - elems_value_flat: Single unstacked value from elems_flat, + # constructed from elems_value_batchable (using + # TensorSpec._from_tensor_list). + # - elems_value: Single unstacked value from elems (the input to fn). + # - result_value: Result of calling `fn(elems_value)`. May contain + # composite tensors. + # - result_value_flat: Flattened list of tensors from result_value. + # May contain composite tensors. + # - result_value_batchable: Concatenation of batchable tensor lists for + # each tensor in result_value_flat + # (using TensorSpec._to_tensor_list). + # + # - Collecting and stacking output values from the while_loop: + # - result_batchable_ta: List of TensorArrays used to stack each tensor + # ta result_value_batchable into result_batchable. + # - result_batchable: Stacked tensors from result_batchable_ta. + # - result_flat: Flat list of tensors for the result, constructed from + # results bactchable (using TensorSpec._from_tensor_list). + # - result: Structured result value packed from results flat + # (using nest.pack_sequence_as). + + if fn_output_signature is None: + fn_output_signature = dtype + if not callable(fn): raise TypeError("fn must be callable.") - if isinstance(elems, sparse_tensor.SparseTensor): - raise TypeError( - "To perform a map on the values of a sparse tensor use either " - " SparseTensor(input.indices, fn(input.values), input.dense_shape) or " - " SparseTensor(input.indices, map_fn(fn, input.values), " - "input.dense_shape)") - in_graph_mode = not context.executing_eagerly() # Set the default number of parallel_iterations depending on graph/eager mode. if in_graph_mode and not parallel_iterations: parallel_iterations = 10 elif not in_graph_mode and not parallel_iterations: parallel_iterations = 1 - - if not in_graph_mode and parallel_iterations > 1: + elif not in_graph_mode and parallel_iterations > 1: logging.log_first_n( logging.WARN, "Setting parallel_iterations > 1 has no " "effect when executing eagerly. Consider calling map_fn" @@ -171,23 +366,25 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, "parallel.", 1) parallel_iterations = 1 - input_is_sequence = nest.is_sequence(elems) - input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] - def input_pack(x): - return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] + # Flatten the input tensors, and get the TypeSpec for each one. + elems_flat = nest.flatten(elems) + elems_flat_signature = [type_spec.type_spec_from_value(e) for e in elems_flat] + elems_unflatten = lambda x: nest.pack_sequence_as(elems, x) - if dtype is None: - output_is_sequence = input_is_sequence - output_flatten = input_flatten - output_pack = input_pack + # Flatten fn's output signature. + if fn_output_signature is None: + # If fn_output_signature was not specified, then assume that it matches the + # input signature. + result_flat_signature = [ + _most_general_compatible_type(s)._unbatch() # pylint: disable=protected-access + for s in elems_flat_signature + ] + result_unflatten = elems_unflatten else: - output_is_sequence = nest.is_sequence(dtype) - output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x] - def output_pack(x): - return (nest.pack_sequence_as(dtype, x) - if output_is_sequence else x[0]) - - elems_flat = input_flatten(elems) + result_flat_signature = [ + _dtype_to_spec(d) for d in nest.flatten(fn_output_signature) + ] + result_unflatten = lambda x: nest.pack_sequence_as(fn_output_signature, x) with ops.name_scope(name, "map", elems_flat): # TODO(akshayka): Remove the in_graph_mode check once caching devices are @@ -204,42 +401,56 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, varscope_caching_device_was_none = True elems_flat = [ - ops.convert_to_tensor(elem, name="elem") for elem in elems_flat] + ops.convert_to_tensor_or_composite(t, name="elem") for t in elems_flat + ] - dtype = dtype or input_pack([elem.dtype for elem in elems_flat]) - dtype_flat = output_flatten(dtype) - - # Convert elems to tensor array. n may be known statically. - static_shape = elems_flat[0].shape - if static_shape.ndims is not None and static_shape.ndims < 1: + # Check that inputs are not scalars. + elems_static_shape = elems_flat[0].shape + if elems_static_shape.ndims is not None and elems_static_shape.ndims < 1: if len(elems_flat) == 1: raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar") else: raise ValueError( "elements in elems must be 1+ dimensional Tensors, not scalars" ) - n = (tensor_shape.dimension_value(static_shape[0]) - or array_ops.shape(elems_flat[0])[0]) - # TensorArrays are always flat - elems_ta = [ - tensor_array_ops.TensorArray(dtype=elem.dtype, - size=n, - dynamic_size=False, - infer_shape=True) - for elem in elems_flat] + # Box any composite tensors into tensor lists. + elems_batchable = _elems_flat_to_batchable(elems_flat) + + # Find the number of iterations, n. (may be known statically.) + n_static = tensor_shape.Dimension( + tensor_shape.dimension_value( + elems_batchable[0].get_shape().with_rank_at_least(1)[0])) + for tensor in elems_batchable[1:]: + n_static.merge_with( + tensor_shape.Dimension( + tensor_shape.dimension_value( + tensor.get_shape().with_rank_at_least(1)[0]))) + n = n_static.value or array_ops.shape(elems_batchable[0])[0] + + # Convert elems to tensor array. + # TODO(edloper): Should we set infer_shape=False for composite tensors? + elems_batchable_ta = [ + tensor_array_ops.TensorArray( + dtype=t.dtype, size=n, dynamic_size=False, infer_shape=True) + for t in elems_batchable + ] # Unpack elements - elems_ta = [ - elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)] + elems_batchable_ta = [ + ta.unstack(t) for (ta, t) in zip(elems_batchable_ta, elems_batchable) + ] i = constant_op.constant(0) - accs_ta = [ - tensor_array_ops.TensorArray(dtype=dt, - size=n, - dynamic_size=False, - infer_shape=infer_shape) - for dt in dtype_flat] + # Prepare result tensor array. + # TODO(edloper): Should we set infer_shape=False for composite tensors? + result_batchable_dtype = _result_flat_signature_to_batchable_dtype( + result_flat_signature) + result_batchable_ta = [ + tensor_array_ops.TensorArray( + dtype=dt, size=n, dynamic_size=False, infer_shape=infer_shape) + for dt in result_batchable_dtype + ] def compute(i, tas): """The loop body of map_fn. @@ -252,30 +463,34 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, (i + 1, tas): the updated counter + updated TensorArrays Raises: - TypeError: if dtype and packed_fn_values structure do not match - ValueType: if dtype and packed_fn_values lengths do not match + TypeError: if fn_output_signature and result_value structure don't match + ValueType: if fn_output_signature and result_value lengths don't match """ - packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta]) - packed_fn_values = fn(packed_values) - nest.assert_same_structure(dtype or elems, packed_fn_values) - flat_fn_values = output_flatten(packed_fn_values) - tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)] + elems_value_batchable = [ta.read(i) for ta in elems_batchable_ta] + elems_value_flat = _elems_value_batchable_to_flat(elems_value_batchable, + elems_flat_signature) + elems_value = elems_unflatten(elems_value_flat) + result_value = fn(elems_value) + nest.assert_same_structure(fn_output_signature or elems, result_value) + result_value_flat = nest.flatten(result_value) + result_value_batchable = _result_value_flat_to_batchable( + result_value_flat, result_flat_signature) + tas = [ + ta.write(i, value) for (ta, value) in zip(tas, result_value_batchable) + ] return (i + 1, tas) _, r_a = control_flow_ops.while_loop( - lambda i, _: i < n, compute, (i, accs_ta), + lambda i, _: i < n, + compute, (i, result_batchable_ta), parallel_iterations=parallel_iterations, back_prop=back_prop, swap_memory=swap_memory, maximum_iterations=n) - results_flat = [r.stack() for r in r_a] + result_batchable = [r.stack() for r in r_a] - n_static = tensor_shape.Dimension(tensor_shape.dimension_value( - elems_flat[0].get_shape().with_rank_at_least(1)[0])) - for elem in elems_flat[1:]: - n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value( - elem.get_shape().with_rank_at_least(1)[0]))) - for r in results_flat: + # Update each output tensor w/ static shape info about the outer dimension. + for r in result_batchable: r.set_shape(tensor_shape.TensorShape(n_static).concatenate( r.get_shape()[1:])) @@ -284,7 +499,103 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=None, back_prop=True, if in_graph_mode and varscope_caching_device_was_none: varscope.set_caching_device(None) - return output_pack(results_flat) + result_flat = _result_batchable_to_flat(result_batchable, + result_flat_signature) + result = result_unflatten(result_flat) + return result + + +def _dtype_to_spec(d): + if not isinstance(d, type_spec.TypeSpec): + d = tensor_spec.TensorSpec(None, d) + return d + + +def _most_general_compatible_type(spec): + """Returns the most general TypeSpec compatible with `spec`.""" + # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API + if isinstance(spec, tensor_spec.TensorSpec): + return tensor_spec.TensorSpec(None, spec.dtype) + elif isinstance(spec, ragged_tensor.RaggedTensorSpec): + # pylint: disable=protected-access + return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank, + spec._row_splits_dtype) + elif isinstance(spec, sparse_tensor.SparseTensorSpec): + # pylint: disable=protected-access + return sparse_tensor.SparseTensorSpec(None, spec.dtype) + else: + return spec + + +def _result_flat_signature_to_batchable_dtype(result_flat_signature): + """Converts result_flat_signature -> result_batchable_dtype.""" + components = [] + for spec in result_flat_signature: + if not isinstance(spec, type_spec.BatchableTypeSpec): + raise TypeError("map_fn can not generate %s outputs" % (spec,)) + # pylint: disable=protected-access + components.extend([s.dtype for s in spec._flat_tensor_specs]) + return components + + +def _elems_flat_to_batchable(elems_flat): + """Converts elems_flat -> elems_batchable.""" + elems_batchable = [] + for elems_tensor in elems_flat: + spec = type_spec.type_spec_from_value(elems_tensor) + if not isinstance(spec, type_spec.BatchableTypeSpec): + raise TypeError("map_fn can not consume %s inputs: got %r" % + (spec, elems_tensor)) + # pylint: disable=protected-access + elems_batchable.extend(spec._to_batched_tensor_list(elems_tensor)) + return elems_batchable + + +def _elems_value_batchable_to_flat(elems_value_batchable, elems_flat_signature): + """Converts elems_value_batchable -> elems_value_flat.""" + elems_value_flat = [] + i = 0 + for spec in elems_flat_signature: + # pylint: disable=protected-access + spec = spec._unbatch() + tensor_list = elems_value_batchable[i:i + len(spec._flat_tensor_specs)] + elems_value_flat.append(spec._from_compatible_tensor_list(tensor_list)) + i += len(tensor_list) + assert i == len(elems_value_batchable) + return elems_value_flat + + +def _result_value_flat_to_batchable(result_value_flat, result_flat_signature): + """Converts result_value_flat -> result_value_batchable.""" + result_value_batchable = [] + for (r_value, r_spec) in zip(result_value_flat, result_flat_signature): + if isinstance(r_spec, tensor_spec.TensorSpec): + result_value_batchable.append(r_value) + else: + if not r_spec.is_compatible_with(r_value): + raise ValueError( + "Error in map_fn:\n Expected `fn` to return a:\n %s\n" + " But it returned a:\n %s\n (value=%s)\n" + " To fix, update the `fn_output_signature` (or `dtype`) " + "argument to `map_fn`." % + (r_spec, type_spec.type_spec_from_value(r_value), r_value)) + result_value_batchable.extend(r_spec._to_tensor_list(r_value)) # pylint: disable=protected-access + return result_value_batchable + + +def _result_batchable_to_flat(result_batchable, result_flat_signature): + """Converts result_batchable -> result_flat.""" + result_flat = [] + i = 0 + for spec in result_flat_signature: + # pylint: disable=protected-access + num_tensors = len(spec._flat_tensor_specs) + result_flat.append( + spec._batch(None)._from_compatible_tensor_list( + result_batchable[i:i + num_tensors])) + i += num_tensors + assert i == len(result_batchable) + return result_flat @tf_export("map_fn", v1=[]) @@ -297,6 +608,7 @@ Use: results = tf.nest.map_structure(tf.stop_gradient, tf.map_fn(fn, elems))""", warn_once=True, back_prop=False) +@deprecation.deprecated_args(None, "Use fn_output_signature instead", "dtype") def map_fn_v2(fn, elems, dtype=None, @@ -304,122 +616,25 @@ def map_fn_v2(fn, back_prop=True, swap_memory=False, infer_shape=True, - name=None): - """map on the list of tensors unpacked from `elems` on dimension 0. - - The simplest version of `map_fn` repeatedly applies the callable `fn` to a - sequence of elements from first to last. The elements are made of the - tensors unpacked from `elems`. `dtype` is the data type of the return - value of `fn`. Users must provide `dtype` if it is different from - the data type of `elems`. - - Suppose that `elems` is unpacked into `values`, a list of tensors. The shape - of the result tensor is `[values.shape[0]] + fn(values[0]).shape`. - - This method also allows multi-arity `elems` and output of `fn`. If `elems` - is a (possibly nested) list or tuple of tensors, then each of these tensors - must have a matching first (unpack) dimension. The signature of `fn` may - match the structure of `elems`. That is, if `elems` is - `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: - `fn = lambda (t1, [t2, t3, [t4, t5]]):`. - - Furthermore, `fn` may emit a different structure than its input. For example, - `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. In this case, - the `dtype` parameter is not optional: `dtype` must be a type or (possibly - nested) tuple of types matching the output of `fn`. - - To apply a functional operation to the nonzero elements of a SparseTensor - one of the following methods is recommended. First, if the function is - expressible as TensorFlow ops, use - - ```python - result = SparseTensor(input.indices, fn(input.values), input.dense_shape) - ``` - - If, however, the function is not expressible as a TensorFlow op, then use - - ```python - result = SparseTensor( - input.indices, map_fn(fn, input.values), input.dense_shape) - ``` - - instead. - - When executing eagerly, map_fn does not execute in parallel even if - `parallel_iterations` is set to a value > 1. You can still get the - performance benefits of running a function in parallel by using the - `tf.function` decorator, - - ```python - # Assume the function being used in map_fn is fn. - # To ensure map_fn calls fn in parallel, use the tf.function decorator. - @tf.function - def func(tensor): - return tf.map_fn(fn, tensor) - ``` - - Note that if you use the `tf.function` decorator, any non-TensorFlow Python - code that you may have written in your function won't get executed. See - [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) for - more details. The recommendation would be to debug without `tf.function` but - switch to it to get performance benefits of running `map_fn` in parallel. - - Args: - fn: The callable to be performed. It accepts one argument, which will have - the same (possibly nested) structure as `elems`. Its output must have the - same structure as `dtype` if one is provided, otherwise it must have the - same structure as `elems`. - elems: A tensor or (possibly nested) sequence of tensors, each of which will - be unpacked along their first dimension. The nested sequence of the - resulting slices will be applied to `fn`. - dtype: (optional) The output type(s) of `fn`. If `fn` returns a structure - of Tensors differing from the structure of `elems`, then `dtype` is not - optional and must have the same structure as the output of `fn`. - parallel_iterations: (optional) The number of iterations allowed to run in - parallel. When graph building, the default value is 10. While executing - eagerly, the default value is set to 1. - back_prop: (optional) Deprecated. False disables support for back - propagation. Prefer using `tf.stop_gradient` instead. - swap_memory: (optional) True enables GPU-CPU memory swapping. - infer_shape: (optional) False disables tests for consistent output shapes. - name: (optional) Name prefix for the returned tensors. - - Returns: - A tensor or (possibly nested) sequence of tensors. Each tensor packs the - results of applying `fn` to tensors unpacked from `elems` along the first - dimension, from first to last. - - Raises: - TypeError: if `fn` is not callable or the structure of the output of - `fn` and `dtype` do not match, or if elems is a SparseTensor. - ValueError: if the lengths of the output of `fn` and `dtype` do not match. - - Examples: - ```python - elems = np.array([1, 2, 3, 4, 5, 6]) - squares = map_fn(lambda x: x * x, elems) - # squares == [1, 4, 9, 16, 25, 36] - ``` - - ```python - elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) - alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64) - # alternate == [-1, 2, -3] - ``` - - ```python - elems = np.array([1, 2, 3]) - alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64)) - # alternates[0] == [1, 2, 3] - # alternates[1] == [-1, -2, -3] - ``` - """ + name=None, + fn_output_signature=None): + """Transform `elems` by applying `fn` to each element unstacked on axis 0.""" + if fn_output_signature is None: + fn_output_signature = dtype return map_fn( fn=fn, elems=elems, - dtype=dtype, + fn_output_signature=fn_output_signature, parallel_iterations=parallel_iterations, back_prop=back_prop, swap_memory=swap_memory, infer_shape=infer_shape, name=name) + + +# Docstring for v2 is the same as v1, except that back_prop is deprecated. +map_fn_v2.__doc__ = re.sub( + r"( back_prop: \(optional\) )(.*)", + r"\1Deprecated: prefer using `tf.stop_gradient` instead. \2", + map_fn.__doc__) +assert "prefer using `tf.stop_gradient` instead" in map_fn_v2.__doc__ diff --git a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py index 5ec6d54bc6b..c0325628e6e 100644 --- a/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_map_fn_op_test.py @@ -46,13 +46,16 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, dict( fn=mo.reduce_mean, elems=[[1, 2, 3], [4, 5], [6, 7]], + elems_dtype=dtypes.int32, expected_output=[2, 4, 6], + result_dtype=dtypes.int32, ), dict( fn=string_ops.reduce_join, elems=[['foo', 'bar', 'baz'], ['a'], ['b', 'c']], expected_output=[b'foobarbaz', b'a', b'bc'], - dtype=dtypes.string, + elems_dtype=dtypes.string, + result_dtype=dtypes.string, ), # [d1, (d2)] -> [d1, 2] dict( @@ -60,7 +63,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, # fn=self.stack_mean_and_sum, elems=[[1, 2, 3], [4, 5], [6, 7]], expected_output=[[2, 6], [4.5, 9], [6.5, 13]], - dtype=dtypes.float32, + elems_dtype=dtypes.float32, + result_dtype=dtypes.float32, expected_ragged_rank=0, ), # [d1, (d2)] -> [d1, (d2)] @@ -68,7 +72,7 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, fn=lambda x: x + np.int64(1), elems=[[1, 2, 3], [4, 5], [6, 7]], expected_output=[[2, 3, 4], [5, 6], [7, 8]], - dtype=dtypes.int64, + elems_dtype=dtypes.int64, result_dtype=ragged_tensor.RaggedTensorType( dtype=dtypes.int64, ragged_rank=1), ), @@ -157,11 +161,11 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, expected_ragged_rank=None, result_ragged_rank=None, elems_ragged_rank=None, - dtype=dtypes.int64, + elems_dtype=dtypes.int64, result_dtype=None, - infer_shape=False, + infer_shape=True, ): - elems = ragged_factory_ops.constant(elems, dtype, elems_ragged_rank) + elems = ragged_factory_ops.constant(elems, elems_dtype, elems_ragged_rank) output = ragged_map_ops.map_fn( fn=fn, elems=elems, dtype=result_dtype, infer_shape=infer_shape) @@ -260,8 +264,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, def testMismatchRaggedRank(self): elems = ragged_factory_ops.constant([[[1, 2, 3]], [[4, 5], [6, 7]]]) fn = lambda x: ragged_math_ops.reduce_sum(x, axis=0) - with self.assertRaisesWithLiteralMatch( - ValueError, r'The declared ragged rank (23) mismatches the result (1)'): + with self.assertRaisesRegexp( + ValueError, r'(?s)Expected `fn` to return.*But it returned.*'): _ = ragged_map_ops.map_fn( fn, elems, @@ -271,8 +275,8 @@ class RaggedMapOpTest(test_util.TensorFlowTestCase, def testMismatchRaggedRank2(self): elems = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [6, 7]]) fn = lambda x: ragged_tensor.RaggedTensor.from_row_starts(x, [0]) - with self.assertRaisesWithLiteralMatch( - ValueError, r'The declared ragged rank (10) mismatches the result (2)'): + with self.assertRaisesRegexp( + ValueError, r'(?s)Expected `fn` to return.*But it returned.*'): _ = ragged_map_ops.map_fn( fn, elems, diff --git a/tensorflow/python/ops/ragged/ragged_map_ops.py b/tensorflow/python/ops/ragged/ragged_map_ops.py index 64bae498b31..69d10529685 100644 --- a/tensorflow/python/ops/ragged/ragged_map_ops.py +++ b/tensorflow/python/ops/ragged/ragged_map_ops.py @@ -17,22 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import collections -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import tensor_array_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.ops.ragged import ragged_config from tensorflow.python.ops.ragged import ragged_tensor -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util.lazy_loader import LazyLoader + + +map_fn_lib = LazyLoader( + "map_fn_lib", globals(), + "tensorflow.python.ops.map_fn") def map_fn(fn, @@ -166,298 +158,25 @@ def map_fn(fn, # out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]]) ``` """ - if not callable(fn): - raise TypeError("fn must be callable.") - - if isinstance(elems, sparse_tensor.SparseTensor): - raise TypeError( - "To perform a map on the values of a sparse tensor use either " - " SparseTensor(input.indices, fn(input.values), input.dense_shape) or " - " SparseTensor(input.indices, map_fn(fn, input.values), " - "input.dense_shape)") - - in_graph_mode = not context.executing_eagerly() - # Set the default number of parallel_iterations depending on graph/eager mode. - if in_graph_mode and not parallel_iterations: - parallel_iterations = 10 - elif not in_graph_mode and not parallel_iterations: - parallel_iterations = 1 - - if not in_graph_mode and parallel_iterations > 1: - logging.log_first_n(logging.WARN, "Setting parallel_iterations > 1 has no " - "effect when executing eagerly. Consider calling map_fn" - " with tf.contrib.eager.defun to execute fn in " - "parallel.", 1) - parallel_iterations = 1 - - input_is_sequence = nest.is_sequence(elems) - input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x] - - def input_pack(x): - return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0] - - elems_flat = input_flatten(elems) - elems_flat = ragged_tensor.match_row_splits_dtypes(*elems_flat) - - with ops.name_scope(name, "map", elems_flat): - # TODO(akshayka): Remove the in_graph_mode check once caching devices are - # supported in Eager - if in_graph_mode: - # Any get_variable calls in fn will cache the first call locally - # and not issue repeated network I/O requests for each iteration. - varscope = vs.get_variable_scope() - varscope_caching_device_was_none = False - if varscope.caching_device is None: - # TODO(ebrevdo): Change to using colocate_with here and in other - # methods. - varscope.set_caching_device(lambda op: op.device) - varscope_caching_device_was_none = True - - elems_flat = [ - ragged_tensor.convert_to_tensor_or_ragged_tensor(elem, name="elem") - for elem in elems_flat - ] - - # We can either infer the output, or we can assume that it will be the same - # as the input structure. - dtype = dtype or input_pack([elem.dtype for elem in elems_flat]) - - # Find the number of iterations, n may be known statically. - if isinstance(elems_flat[0], ragged_tensor.RaggedTensor): - n = elems_flat[0].nrows(out_type=dtypes.int32) - else: - static_shape = elems_flat[0].shape - if static_shape.ndims is not None and static_shape.ndims < 1: - if len(elems_flat) == 1: - raise ValueError( - "elems must be a 1+ dimensional Tensor, not a scalar") - else: - raise ValueError( - "elements in elems must be 1+ dimensional Tensors, not scalars") - n = (tensor_shape.dimension_value(static_shape[0]) or - array_ops.shape(elems_flat[0])[0]) - - n = math_ops.cast(n, dtype=dtypes.int32) - # Create a flat list of TAs. - - # Flatten the dtype structure to a list. - dtype_flat = nest.flatten(dtype) - - # decompose to components - dtype_components = [_maybe_decompose_dtype(d) for d in dtype_flat] - dtype_components_flat = nest.flatten(dtype_components) - - # Create TensorArrays. - accs_ta = [ - tensor_array_ops.TensorArray( - dtype=t, dynamic_size=False, infer_shape=infer_shape, size=n) - for t in dtype_components_flat - ] - - i = constant_op.constant(0, dtype=dtypes.int32) - - def compute(i, tas): - """The loop body of map_fn. - - Args: - i: the loop counter - tas: the flat TensorArray accumulator list - - Returns: - (i + 1, tas): the updated counter + updated TensorArrays - - Raises: - TypeError: if dtype and packed_fn_values structure do not match - ValueType: if dtype and packed_fn_values lengths do not match - """ - # Get Tensors or RaggedTensors sliced at i, then pack it back to the - # original structure. - packed_values = input_pack([elem_flat[i] for elem_flat in elems_flat]) - packed_fn_values = fn(packed_values) - - # Check that the structure of the output matches what was declared or - # inferred. - # nest.assert_same_structure(dtype or elems, packed_fn_values) - - # Flatten and decompose to a list of Tensors - flat_fn_values = nest.flatten(packed_fn_values) - - # If we declared that we are expecting a RaggedTensor output, but we get a - # Tensor output. We should try to convert it to a RaggedTensor. - flat_fn_composite_tensors = list( - _convert_declared(flat_fn_values, dtype_flat)) - - flat_fn_components = [ - _maybe_decompose_tensor(t) for t in flat_fn_composite_tensors - ] - flat_fn_tensors = nest.flatten(flat_fn_components) - - # Write to TAs. - tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_tensors)] - - return (i + 1, tas) - - _, r_a = control_flow_ops.while_loop( - lambda i, _: i < n, compute, (i, accs_ta), - parallel_iterations=parallel_iterations, - back_prop=back_prop, - swap_memory=swap_memory, - maximum_iterations=n) - - # TODO(akshayka): Remove the in_graph_mode check once caching devices are - # supported in Eager - if in_graph_mode and varscope_caching_device_was_none: - varscope.set_caching_device(None) - - # Pack back into a list of components - results_as_components = nest.pack_sequence_as(dtype_components, r_a) - - # Stack TensorArrays for Tensor outputs, and concat RaggedTensor outputs. - def _stack_or_concat(e): - if isinstance(e, _RaggedTensorComponents): - return _concat_ragged_tensor_components(e) - else: - result = e.stack() - return result - - results_flat_components = [ - _stack_or_concat(e) for e in results_as_components - ] - - results_packed = [ - _maybe_recompose_tensor(c) for c in results_flat_components - ] - results_packed = nest.pack_sequence_as(dtype, results_packed) - return results_packed + if dtype is None: + dtype = nest.map_structure(lambda e: e.dtype, elems) + dtype = nest.map_structure(_ragged_type_to_spec, dtype) + return map_fn_lib.map_fn(fn, + elems, + dtype, + parallel_iterations, + back_prop, + swap_memory, + infer_shape, + name) -class _RaggedTensorComponents( - collections.namedtuple( - "_RaggedTensorComponents", - ["flat_values", "nested_row_lengths", "outer_row_length"])): - """A namedtuple of components which represent a `RaggedTensor`. - - _RaggedTensorComponents is a list of components which can be used to create a - `RaggedTensor`. Use this class to represent a `RaggedTensor` in situations - where nest.flatten and nest.pack_sequence_as should decompose ragged tensors - into their components.. - - The following are a list of components for a `RaggedTensor`: - - flat_values: The flat and inner values of a RaggedTensor. This could be - a `Tensor`, a `TensorArray`, or a data type. - nested_row_lengths: a tuple containing the row lengths of each rank. The - elements of the tuple could be `Tensor`s or `TensorArray`s. - outer_row_length: a `Tensor` or `TensorArray` containing the row length of the - `RaggedTensor`'s outermost dimension. - - See `RaggedTensor` for more details of the use of each component. - """ - __slots__ = () - - -def _concat_ragged_tensor_components(rt_ta): - flat_values = rt_ta.flat_values.concat() - nested_row_lengths = tuple( - row_lengths_ta.concat() for row_lengths_ta in rt_ta.nested_row_lengths) - outer_row_length = rt_ta.outer_row_length.concat() - return _RaggedTensorComponents( - flat_values=flat_values, - nested_row_lengths=nested_row_lengths, - outer_row_length=outer_row_length) - - -def _maybe_decompose_tensor(rt): - """Decompose tensors to their composite tensors.""" - if not isinstance(rt, ragged_tensor.RaggedTensor): - return rt - - # The three component pieces we need: - # - inner values - flat_values = rt.flat_values - - # - row_splits of the RT - splits = rt.nested_row_splits - nested_row_lengths = tuple(split[1:] - split[:-1] for split in splits) - - # - outer row length - outer_row_length = array_ops.expand_dims(rt.nrows(), axis=0) - - return _RaggedTensorComponents( - flat_values=flat_values, - nested_row_lengths=nested_row_lengths, - outer_row_length=outer_row_length, - ) - - -def _maybe_recompose_tensor(t): - """Reconstructs a _RaggedTensorComponents into a RaggedTensor.""" - if not isinstance(t, _RaggedTensorComponents): - return t - - values = t.flat_values - nested_row_lengths = tuple(t.nested_row_lengths) - for nested_row_length in reversed(nested_row_lengths): - values = ragged_tensor.RaggedTensor.from_row_lengths( - values, nested_row_length, validate=False) - return ragged_tensor.RaggedTensor.from_row_lengths(values, t.outer_row_length, - validate=False) - - -def _maybe_decompose_dtype(d): - """Decompose dtypes into composite tensors (if necessary).""" - if not isinstance(d, ragged_tensor.RaggedTensorType): - return d - - result = _RaggedTensorComponents( - flat_values=d.dtype, - nested_row_lengths=tuple( - d.row_splits_dtype for i in range(d.ragged_rank - 1)), - outer_row_length=d.row_splits_dtype, - ) - return result - - -def _convert_declared(fn_output_flat, output_declared): - """Convert outputs which are `Tensor`s into `_RaggedTensorComponents`.""" - for current, declared in zip(fn_output_flat, output_declared): - if isinstance(declared, ragged_tensor.RaggedTensorType): - yield _convert_declared_ragged(current, declared) - else: - yield current - - -def _convert_declared_ragged(current, declared): - """Converts an output with RaggedTensorType into a _RaggedTensorComponents.""" - # Check that the ragged ranks match up. - # + 1 to account for the rank of the outermost dimension. - current_ragged_rank = getattr(current, "ragged_rank", 0) - if declared.ragged_rank != current_ragged_rank + 1: - raise ValueError( - "The declared ragged rank (%d) mismatches the result (%d)" % - (declared.ragged_rank, current_ragged_rank + 1)) - - # Check that dtypes match up. - if declared.dtype != current.dtype: - raise ValueError( - "The declared dtype (%s) mismatches the result (%s)" % - (declared.dtype, current.dtype)) - if (isinstance(current, ragged_tensor.RaggedTensor) and - declared.row_splits_dtype != current.row_splits.dtype): - if not ragged_config.auto_cast_partition_dtype(): - raise ValueError( - "The declared row_splits dtype (%s) mismatches the result (%s)." - " Use RaggedTensor.with_row_splits_dtype to convert it." - % (declared.row_splits_dtype, current.row_splits.dtype)) - current = current.with_row_splits_dtype(declared.row_splits_dtype) - - if isinstance(current, ragged_tensor.RaggedTensor): - return current +def _ragged_type_to_spec(t): + if isinstance(t, ragged_tensor.RaggedTensorType): + # Note: need to adjust ragged_rank by 1, since RaggedTensorSpec gives the + # type for the mapped `fn` output, but RaggedTensorType gives the type for + # the result of stacking the mapped `fn` outputs. + return ragged_tensor.RaggedTensorSpec( + None, t.dtype, t.ragged_rank - 1, t.row_splits_dtype) else: - nrows = array_ops.shape(current, out_type=declared.row_splits_dtype)[0] - row_length = array_ops.expand_dims(nrows, axis=0) - return _RaggedTensorComponents( - flat_values=current, - nested_row_lengths=(), - outer_row_length=row_length) - + return t diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 33aa56873b0..050d55be118 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -2330,6 +2330,10 @@ class RaggedTensorType(object): ragged_rank = property(lambda self: self._ragged_rank) row_splits_dtype = property(lambda self: self._row_splits_dtype) + def __repr__(self): + return "RaggedTensorType(%r, %r, %r)" % ( + self.dtype, self.ragged_rank, self.row_splits_dtype) + #=============================================================================== # Helper Functions diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 2f7c4e8bbd3..70f60e5cb92 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1626,7 +1626,7 @@ tf_module { } member_method { name: "map_fn" - argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], " } member_method { name: "matching_files" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index c56730870eb..a5200f86bfa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -786,7 +786,7 @@ tf_module { } member_method { name: "map_fn" - argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\'], " + argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\', \'fn_output_signature\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'False\', \'True\', \'None\', \'None\'], " } member_method { name: "matmul"