Preserve static information about outer dimension size in tf.map_fn with ragged/sparse outputs.
PiperOrigin-RevId: 333739750 Change-Id: Iedd8c18c0ba7b94f1487365e52f14a8926c92f90
This commit is contained in:
parent
eb0f1eda6d
commit
ba73e805fe
tensorflow/python
@ -20,8 +20,8 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
@ -33,6 +33,8 @@ from tensorflow.python.ops import map_fn
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -79,6 +81,17 @@ class MapFnTest(test.TestCase):
|
||||
self.assertAllEqual(result.values, st.values)
|
||||
self.assertAllEqual(result.dense_shape, st.dense_shape)
|
||||
|
||||
def testMapRaggedTensor(self):
|
||||
# Note: there are additional tests in ragged/ragged_map_fn_op_test.py
|
||||
with self.cached_session():
|
||||
rt = ragged_factory_ops.constant([[1, 2], [3]])
|
||||
result = map_fn.map_fn(
|
||||
lambda x: x + 1,
|
||||
rt,
|
||||
fn_output_signature=ragged_tensor.RaggedTensorSpec([None], rt.dtype))
|
||||
self.assertAllEqual([[2, 3], [4]], result)
|
||||
self.assertEqual([2, None], result.shape.as_list())
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testMapOverScalarErrors(self):
|
||||
with self.assertRaisesRegex(ValueError, "not scalars"):
|
||||
|
@ -526,7 +526,8 @@ def map_fn(fn,
|
||||
varscope.set_caching_device(None)
|
||||
|
||||
result_flat = _result_batchable_to_flat(result_batchable,
|
||||
result_flat_signature)
|
||||
result_flat_signature,
|
||||
n_static)
|
||||
result = result_unflatten(result_flat)
|
||||
return result
|
||||
|
||||
@ -608,7 +609,8 @@ def _result_value_flat_to_batchable(result_value_flat, result_flat_signature):
|
||||
return result_value_batchable
|
||||
|
||||
|
||||
def _result_batchable_to_flat(result_batchable, result_flat_signature):
|
||||
def _result_batchable_to_flat(result_batchable, result_flat_signature,
|
||||
batch_size):
|
||||
"""Converts result_batchable -> result_flat."""
|
||||
result_flat = []
|
||||
i = 0
|
||||
@ -616,7 +618,7 @@ def _result_batchable_to_flat(result_batchable, result_flat_signature):
|
||||
# pylint: disable=protected-access
|
||||
num_tensors = len(spec._flat_tensor_specs)
|
||||
result_flat.append(
|
||||
spec._batch(None)._from_compatible_tensor_list(
|
||||
spec._batch(batch_size)._from_compatible_tensor_list(
|
||||
result_batchable[i:i + num_tensors]))
|
||||
i += num_tensors
|
||||
assert i == len(result_batchable)
|
||||
|
Loading…
Reference in New Issue
Block a user