Preserve static information about outer dimension size in tf.map_fn with ragged/sparse outputs.

PiperOrigin-RevId: 333739750
Change-Id: Iedd8c18c0ba7b94f1487365e52f14a8926c92f90
Edward Loper 2020-09-25 08:39:56 -07:00 committed by TensorFlower Gardener
parent eb0f1eda6d
commit ba73e805fe
2 changed files with 19 additions and 4 deletions
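For context, a rough sketch of the user-visible effect. It mirrors the new testMapRaggedTensor case in the diff below, rewritten against the public tf.* aliases; the shapes in the comments are the ones asserted by that test:

    import tensorflow as tf

    rt = tf.ragged.constant([[1, 2], [3]])  # outer dimension is statically 2
    result = tf.map_fn(
        lambda x: x + 1,
        rt,
        fn_output_signature=tf.RaggedTensorSpec([None], rt.dtype))

    # With this change the statically known batch size is preserved:
    print(result.shape.as_list())  # [2, None]; the outer dimension was previously None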
tensorflow/python

@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.python.eager import def_function
from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
@@ -33,6 +33,8 @@ from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_factory_ops
+from tensorflow.python.ops.ragged import ragged_tensor
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
@@ -79,6 +81,17 @@ class MapFnTest(test.TestCase):
      self.assertAllEqual(result.values, st.values)
      self.assertAllEqual(result.dense_shape, st.dense_shape)

+  def testMapRaggedTensor(self):
+    # Note: there are additional tests in ragged/ragged_map_fn_op_test.py
+    with self.cached_session():
+      rt = ragged_factory_ops.constant([[1, 2], [3]])
+      result = map_fn.map_fn(
+          lambda x: x + 1,
+          rt,
+          fn_output_signature=ragged_tensor.RaggedTensorSpec([None], rt.dtype))
+      self.assertAllEqual([[2, 3], [4]], result)
+      self.assertEqual([2, None], result.shape.as_list())
+
  @test_util.run_in_graph_and_eager_modes
  def testMapOverScalarErrors(self):
    with self.assertRaisesRegex(ValueError, "not scalars"):

@@ -526,7 +526,8 @@ def map_fn(fn,
      varscope.set_caching_device(None)

    result_flat = _result_batchable_to_flat(result_batchable,
-                                            result_flat_signature)
+                                            result_flat_signature,
+                                            n_static)

    result = result_unflatten(result_flat)
    return result
@@ -608,7 +609,8 @@ def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
  return result_value_batchable


-def _result_batchable_to_flat(result_batchable, result_flat_signature):
+def _result_batchable_to_flat(result_batchable, result_flat_signature,
+                              batch_size):
  """Converts result_batchable -> result_flat."""
  result_flat = []
  i = 0
@@ -616,7 +618,7 @@ def _result_batchable_to_flat(result_batchable, result_flat_signature):
    # pylint: disable=protected-access
    num_tensors = len(spec._flat_tensor_specs)
    result_flat.append(
-        spec._batch(None)._from_compatible_tensor_list(
+        spec._batch(batch_size)._from_compatible_tensor_list(
            result_batchable[i:i + num_tensors]))
    i += num_tensors
  assert i == len(result_batchable)
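
The mechanism is the last hunk above: the per-element TypeSpec is batched with the statically known batch size instead of None. A minimal sketch of the difference, using TensorFlow-internal TypeSpec methods (private API, shown only for illustration):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops.ragged import ragged_tensor

    elem_spec = ragged_tensor.RaggedTensorSpec([None], dtypes.int32)
    # pylint: disable=protected-access
    print(elem_spec._batch(None))  # batched spec: outer dimension unknown (None)
    print(elem_spec._batch(2))     # batched spec: outer dimension statically 2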